42
42
)
43
43
44
44
45
class ResnetBlockCondNorm2D(nn.Module):
    r"""
    A Resnet block whose normalization layers incorporate conditioning information.

    Both normalization layers receive the conditioning tensor `temb` in `forward`, instead of `temb` being
    projected and added to the hidden states.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv2d layer. If `None`, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`):
            Stored on the module as `use_conv_shortcut`; it is not read anywhere else in this class.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, defaults to `512`):
            The number of channels of the conditioning embedding handed to the normalization layers.
        groups (`int`, *optional*, defaults to `32`):
            The number of groups to use for the first normalization layer. Only used when `time_embedding_norm`
            is `"ada_group"`.
        groups_out (`int`, *optional*, defaults to `None`):
            The number of groups to use for the second normalization layer. If `None`, same as `groups`. Only
            used when `time_embedding_norm` is `"ada_group"`.
        eps (`float`, *optional*, defaults to `1e-6`):
            The epsilon to use for the normalization. Only used when `time_embedding_norm` is `"ada_group"`.
        non_linearity (`str`, *optional*, defaults to `"swish"`): The activation function to use.
        time_embedding_norm (`str`, *optional*, defaults to `"ada_group"`):
            The conditional normalization layer to use. Only `"ada_group"` and `"spatial"` are supported.
        output_scale_factor (`float`, *optional*, defaults to `1.0`):
            The value the sum of the residual and the hidden states is divided by.
        use_in_shortcut (`bool`, *optional*, defaults to `None`):
            If `True`, add a 1x1 conv2d layer on the skip connection. If `None`, the layer is added only when
            `in_channels` differs from the number of output channels.
        up (`bool`, *optional*, defaults to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, defaults to `False`):
            If `True`, add a downsample layer. Ignored when `up` is also `True`.
        conv_shortcut_bias (`bool`, *optional*, defaults to `True`):
            If `True`, adds a learnable bias to the `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If `None`, same as `out_channels`.

    Raises:
        ValueError: If `time_embedding_norm` is neither `"ada_group"` nor `"spatial"`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "ada_group",  # ada_group, spatial
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()

        # Validate once, before any sub-module is built, so that no later branch needs its own
        # (otherwise unreachable) error path.
        if time_embedding_norm not in ("ada_group", "spatial"):
            raise ValueError(f" unsupported time_embedding_norm: {time_embedding_norm} ")

        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if groups_out is None:
            groups_out = groups

        def make_cond_norm(num_channels: int, num_groups: int) -> nn.Module:
            # Single place that maps `time_embedding_norm` to the conditional norm layer.
            if time_embedding_norm == "ada_group":
                return AdaGroupNorm(temb_channels, num_channels, num_groups, eps=eps)
            # Only "spatial" is left after the validation above; it takes neither groups nor eps.
            return SpatialNorm(num_channels, temb_channels)

        # NOTE: sub-modules are created in the same order as before (norm1, conv1, norm2, dropout, conv2, ...)
        # so that registration order and parameter initialization order are unchanged.
        self.norm1 = make_cond_norm(in_channels, groups)

        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = make_cond_norm(out_channels, groups_out)

        self.dropout = torch.nn.Dropout(dropout)

        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        # At most one resampling layer is created; `up` takes precedence over `down`.
        self.upsample = self.downsample = None
        if self.up:
            self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        # When not specified, a shortcut projection is only needed if the channel count changes.
        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = conv_cls(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        r"""
        Run the block: norm1 -> activation -> optional resampling -> conv1 -> norm2 -> activation -> dropout ->
        conv2, then add the (optionally resampled and projected) input and divide by `output_scale_factor`.

        Args:
            input_tensor (`torch.FloatTensor`): The input; also used as the residual branch.
            temb (`torch.FloatTensor`): The conditioning tensor passed to both normalization layers.
            scale (`float`, *optional*, defaults to `1.0`):
                Forwarded to the up/downsample layer and, when `USE_PEFT_BACKEND` is false, passed as the
                second positional argument to the convolution layers.

        Returns:
            `torch.FloatTensor`: `(residual + hidden_states) / output_scale_factor`.
        """
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states, temb)

        hidden_states = self.nonlinearity(hidden_states)

        # The residual branch is resampled with the same layer so both branches stay the same size.
        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor, scale=scale)
            hidden_states = self.upsample(hidden_states, scale=scale)

        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor, scale=scale)
            hidden_states = self.downsample(hidden_states, scale=scale)

        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states, temb)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = (
                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
            )

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
193
+
194
+
45
195
class ResnetBlock2D (nn .Module ):
46
196
r"""
47
197
A Resnet block.
@@ -58,8 +208,8 @@ class ResnetBlock2D(nn.Module):
58
208
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
59
209
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
60
210
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
61
- By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
62
- "ada_group" for a stronger conditioning with scale and shift.
211
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
212
+ for a stronger conditioning with scale and shift.
63
213
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
64
214
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
65
215
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
@@ -87,7 +237,7 @@ def __init__(
87
237
eps : float = 1e-6 ,
88
238
non_linearity : str = "swish" ,
89
239
skip_time_act : bool = False ,
90
- time_embedding_norm : str = "default" , # default, scale_shift, ada_group, spatial
240
+ time_embedding_norm : str = "default" , # default, scale_shift,
91
241
kernel : Optional [torch .FloatTensor ] = None ,
92
242
output_scale_factor : float = 1.0 ,
93
243
use_in_shortcut : Optional [bool ] = None ,
@@ -97,7 +247,15 @@ def __init__(
97
247
conv_2d_out_channels : Optional [int ] = None ,
98
248
):
99
249
super ().__init__ ()
100
- self .pre_norm = pre_norm
250
+ if time_embedding_norm == "ada_group" :
251
+ raise ValueError (
252
+ "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead" ,
253
+ )
254
+ if time_embedding_norm == "spatial" :
255
+ raise ValueError (
256
+ "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead" ,
257
+ )
258
+
101
259
self .pre_norm = True
102
260
self .in_channels = in_channels
103
261
out_channels = in_channels if out_channels is None else out_channels
@@ -115,12 +273,7 @@ def __init__(
115
273
if groups_out is None :
116
274
groups_out = groups
117
275
118
- if self .time_embedding_norm == "ada_group" :
119
- self .norm1 = AdaGroupNorm (temb_channels , in_channels , groups , eps = eps )
120
- elif self .time_embedding_norm == "spatial" :
121
- self .norm1 = SpatialNorm (in_channels , temb_channels )
122
- else :
123
- self .norm1 = torch .nn .GroupNorm (num_groups = groups , num_channels = in_channels , eps = eps , affine = True )
276
+ self .norm1 = torch .nn .GroupNorm (num_groups = groups , num_channels = in_channels , eps = eps , affine = True )
124
277
125
278
self .conv1 = conv_cls (in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 )
126
279
@@ -129,19 +282,12 @@ def __init__(
129
282
self .time_emb_proj = linear_cls (temb_channels , out_channels )
130
283
elif self .time_embedding_norm == "scale_shift" :
131
284
self .time_emb_proj = linear_cls (temb_channels , 2 * out_channels )
132
- elif self .time_embedding_norm == "ada_group" or self .time_embedding_norm == "spatial" :
133
- self .time_emb_proj = None
134
285
else :
135
286
raise ValueError (f"unknown time_embedding_norm : { self .time_embedding_norm } " )
136
287
else :
137
288
self .time_emb_proj = None
138
289
139
- if self .time_embedding_norm == "ada_group" :
140
- self .norm2 = AdaGroupNorm (temb_channels , out_channels , groups_out , eps = eps )
141
- elif self .time_embedding_norm == "spatial" :
142
- self .norm2 = SpatialNorm (out_channels , temb_channels )
143
- else :
144
- self .norm2 = torch .nn .GroupNorm (num_groups = groups_out , num_channels = out_channels , eps = eps , affine = True )
290
+ self .norm2 = torch .nn .GroupNorm (num_groups = groups_out , num_channels = out_channels , eps = eps , affine = True )
145
291
146
292
self .dropout = torch .nn .Dropout (dropout )
147
293
conv_2d_out_channels = conv_2d_out_channels or out_channels
@@ -188,11 +334,7 @@ def forward(
188
334
) -> torch .FloatTensor :
189
335
hidden_states = input_tensor
190
336
191
- if self .time_embedding_norm == "ada_group" or self .time_embedding_norm == "spatial" :
192
- hidden_states = self .norm1 (hidden_states , temb )
193
- else :
194
- hidden_states = self .norm1 (hidden_states )
195
-
337
+ hidden_states = self .norm1 (hidden_states )
196
338
hidden_states = self .nonlinearity (hidden_states )
197
339
198
340
if self .upsample is not None :
@@ -233,17 +375,20 @@ def forward(
233
375
else self .time_emb_proj (temb )[:, :, None , None ]
234
376
)
235
377
236
- if temb is not None and self .time_embedding_norm == "default" :
237
- hidden_states = hidden_states + temb
238
-
239
- if self .time_embedding_norm == "ada_group" or self .time_embedding_norm == "spatial" :
240
- hidden_states = self .norm2 (hidden_states , temb )
241
- else :
378
+ if self .time_embedding_norm == "default" :
379
+ if temb is not None :
380
+ hidden_states = hidden_states + temb
242
381
hidden_states = self .norm2 (hidden_states )
243
-
244
- if temb is not None and self .time_embedding_norm == "scale_shift" :
382
+ elif self .time_embedding_norm == "scale_shift" :
383
+ if temb is None :
384
+ raise ValueError (
385
+ f" `temb` should not be None when `time_embedding_norm` is { self .time_embedding_norm } "
386
+ )
245
387
scale , shift = torch .chunk (temb , 2 , dim = 1 )
388
+ hidden_states = self .norm2 (hidden_states )
246
389
hidden_states = hidden_states * (1 + scale ) + shift
390
+ else :
391
+ hidden_states = self .norm2 (hidden_states )
247
392
248
393
hidden_states = self .nonlinearity (hidden_states )
249
394
0 commit comments