Skip to content

Commit dd4459a

Browse files
yiyixuxuyiyixuxupatrickvonplatensayakpaul
authored
[Refactor] splitting ResnetBlock2D into multiple blocks (#6166)
--------- Co-authored-by: yiyixuxu <yixu310@gmail.com> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 6313645 commit dd4459a

File tree

3 files changed

+418
-142
lines changed

3 files changed

+418
-142
lines changed

src/diffusers/models/resnet.py

Lines changed: 176 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,156 @@
4242
)
4343

4444

45+
class ResnetBlockCondNorm2D(nn.Module):
    r"""
    A Resnet block that uses normalization layers that incorporate conditioning information.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`):
            Stored on the module as `use_conv_shortcut`; it is not read anywhere else in this class.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, defaults to `512`): The number of channels in the timestep embedding.
        groups (`int`, *optional*, defaults to `32`):
            The number of groups to use for the first normalization layer. Only passed to the "ada_group" norm.
        groups_out (`int`, *optional*, defaults to `None`):
            The number of groups to use for the second normalization layer. If None, same as `groups`.
        eps (`float`, *optional*, defaults to `1e-6`):
            The epsilon to use for the normalization. Only passed to the "ada_group" norm.
        non_linearity (`str`, *optional*, defaults to `"swish"`): The activation function to use.
        time_embedding_norm (`str`, *optional*, defaults to `"ada_group"`):
            The normalization layer for the time embedding `temb`. Only "ada_group" and "spatial" are supported;
            any other value raises a `ValueError`.
        output_scale_factor (`float`, *optional*, defaults to `1.0`):
            The residual sum is divided by this factor to produce the output.
        use_in_shortcut (`bool`, *optional*, defaults to `None`):
            If `True`, add a 1x1 conv2d layer for the skip-connection. If `None`, the layer is added only when
            `in_channels` differs from the block's final number of output channels.
        up (`bool`, *optional*, defaults to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, defaults to `False`):
            If `True` (and `up` is `False`), add a downsample layer.
        conv_shortcut_bias (`bool`, *optional*, defaults to `True`):
            If `True`, adds a learnable bias to the `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "ada_group",  # ada_group, spatial
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm

        # Validate once up front instead of repeating the check for every norm layer.
        if self.time_embedding_norm not in ("ada_group", "spatial"):
            raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")

        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if groups_out is None:
            groups_out = groups

        def build_cond_norm(num_channels: int, num_groups: int) -> nn.Module:
            # Both norm variants are conditioned on `temb`; only AdaGroupNorm takes groups/eps.
            if self.time_embedding_norm == "ada_group":
                return AdaGroupNorm(temb_channels, num_channels, num_groups, eps=eps)
            return SpatialNorm(num_channels, temb_channels)

        # NOTE: sub-modules are created in the same order as before (norm1, conv1, norm2, ...)
        # so parameter registration order and random-init order are unchanged.
        self.norm1 = build_cond_norm(in_channels, groups)

        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = build_cond_norm(out_channels, groups_out)

        self.dropout = torch.nn.Dropout(dropout)

        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)

        # At most one resampler is created; `up` takes precedence over `down`.
        self.upsample = self.downsample = None
        if self.up:
            self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")

        # Unless explicitly requested, project the skip path only when channel counts differ.
        self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = conv_cls(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: torch.FloatTensor,
        scale: float = 1.0,
    ) -> torch.FloatTensor:
        r"""
        Run the residual block.

        Args:
            input_tensor (`torch.FloatTensor`): The input feature map; also used as the skip-connection.
            temb (`torch.FloatTensor`): The conditioning tensor passed to both normalization layers.
            scale (`float`, *optional*, defaults to `1.0`):
                Forwarded to the up/downsample layers, and to the conv layers when `USE_PEFT_BACKEND` is false
                (presumably the LoRA scale -- confirm against `LoRACompatibleConv`).

        Returns:
            `torch.FloatTensor`: `(shortcut(input_tensor) + hidden_states) / output_scale_factor`.
        """
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states, temb)

        hidden_states = self.nonlinearity(hidden_states)

        # The skip path and the main path are resampled identically so they can be summed later.
        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor, scale=scale)
            hidden_states = self.upsample(hidden_states, scale=scale)

        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor, scale=scale)
            hidden_states = self.downsample(hidden_states, scale=scale)

        # Plain `nn.Conv2d` (PEFT backend) does not accept a `scale` argument.
        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states, temb)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = (
                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
            )

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
193+
194+
45195
class ResnetBlock2D(nn.Module):
46196
r"""
47197
A Resnet block.
@@ -58,8 +208,8 @@ class ResnetBlock2D(nn.Module):
58208
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
59209
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
60210
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
61-
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
62-
"ada_group" for a stronger conditioning with scale and shift.
211+
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
212+
for a stronger conditioning with scale and shift.
63213
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
64214
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
65215
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
@@ -87,7 +237,7 @@ def __init__(
87237
eps: float = 1e-6,
88238
non_linearity: str = "swish",
89239
skip_time_act: bool = False,
90-
time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
240+
time_embedding_norm: str = "default", # default, scale_shift,
91241
kernel: Optional[torch.FloatTensor] = None,
92242
output_scale_factor: float = 1.0,
93243
use_in_shortcut: Optional[bool] = None,
@@ -97,7 +247,15 @@ def __init__(
97247
conv_2d_out_channels: Optional[int] = None,
98248
):
99249
super().__init__()
100-
self.pre_norm = pre_norm
250+
if time_embedding_norm == "ada_group":
251+
raise ValueError(
252+
"This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead",
253+
)
254+
if time_embedding_norm == "spatial":
255+
raise ValueError(
256+
"This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead",
257+
)
258+
101259
self.pre_norm = True
102260
self.in_channels = in_channels
103261
out_channels = in_channels if out_channels is None else out_channels
@@ -115,12 +273,7 @@ def __init__(
115273
if groups_out is None:
116274
groups_out = groups
117275

118-
if self.time_embedding_norm == "ada_group":
119-
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
120-
elif self.time_embedding_norm == "spatial":
121-
self.norm1 = SpatialNorm(in_channels, temb_channels)
122-
else:
123-
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
276+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
124277

125278
self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
126279

@@ -129,19 +282,12 @@ def __init__(
129282
self.time_emb_proj = linear_cls(temb_channels, out_channels)
130283
elif self.time_embedding_norm == "scale_shift":
131284
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
132-
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
133-
self.time_emb_proj = None
134285
else:
135286
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
136287
else:
137288
self.time_emb_proj = None
138289

139-
if self.time_embedding_norm == "ada_group":
140-
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
141-
elif self.time_embedding_norm == "spatial":
142-
self.norm2 = SpatialNorm(out_channels, temb_channels)
143-
else:
144-
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
290+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
145291

146292
self.dropout = torch.nn.Dropout(dropout)
147293
conv_2d_out_channels = conv_2d_out_channels or out_channels
@@ -188,11 +334,7 @@ def forward(
188334
) -> torch.FloatTensor:
189335
hidden_states = input_tensor
190336

191-
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
192-
hidden_states = self.norm1(hidden_states, temb)
193-
else:
194-
hidden_states = self.norm1(hidden_states)
195-
337+
hidden_states = self.norm1(hidden_states)
196338
hidden_states = self.nonlinearity(hidden_states)
197339

198340
if self.upsample is not None:
@@ -233,17 +375,20 @@ def forward(
233375
else self.time_emb_proj(temb)[:, :, None, None]
234376
)
235377

236-
if temb is not None and self.time_embedding_norm == "default":
237-
hidden_states = hidden_states + temb
238-
239-
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
240-
hidden_states = self.norm2(hidden_states, temb)
241-
else:
378+
if self.time_embedding_norm == "default":
379+
if temb is not None:
380+
hidden_states = hidden_states + temb
242381
hidden_states = self.norm2(hidden_states)
243-
244-
if temb is not None and self.time_embedding_norm == "scale_shift":
382+
elif self.time_embedding_norm == "scale_shift":
383+
if temb is None:
384+
raise ValueError(
385+
f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
386+
)
245387
scale, shift = torch.chunk(temb, 2, dim=1)
388+
hidden_states = self.norm2(hidden_states)
246389
hidden_states = hidden_states * (1 + scale) + shift
390+
else:
391+
hidden_states = self.norm2(hidden_states)
247392

248393
hidden_states = self.nonlinearity(hidden_states)
249394

0 commit comments

Comments
 (0)