Skip to content

Commit f64f2d0

Browse files
committed
update
1 parent 6b53b85 commit f64f2d0

File tree

2 files changed

+36
-27
lines changed

2 files changed

+36
-27
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_allegro.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ..upsampling import Upsample2D
3030

3131

32-
class AllegroTemporalConvBlock(nn.Module):
32+
class AllegroTemporalConvLayer(nn.Module):
3333
r"""
3434
Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from:
3535
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
@@ -40,6 +40,7 @@ def __init__(
4040
in_dim: int,
4141
out_dim: Optional[int] = None,
4242
dropout: float = 0.0,
43+
norm_num_groups: int = 32,
4344
up_sample: bool = False,
4445
down_sample: bool = False,
4546
stride: int = 1,
@@ -55,44 +56,40 @@ def __init__(
5556

5657
if down_sample:
5758
self.conv1 = nn.Sequential(
58-
nn.GroupNorm(32, in_dim),
59+
nn.GroupNorm(norm_num_groups, in_dim),
5960
nn.SiLU(),
6061
nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)),
6162
)
6263
elif up_sample:
6364
self.conv1 = nn.Sequential(
64-
nn.GroupNorm(32, in_dim),
65+
nn.GroupNorm(norm_num_groups, in_dim),
6566
nn.SiLU(),
6667
nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)),
6768
)
6869
else:
6970
self.conv1 = nn.Sequential(
70-
nn.GroupNorm(32, in_dim),
71+
nn.GroupNorm(norm_num_groups, in_dim),
7172
nn.SiLU(),
7273
nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
7374
)
7475
self.conv2 = nn.Sequential(
75-
nn.GroupNorm(32, out_dim),
76+
nn.GroupNorm(norm_num_groups, out_dim),
7677
nn.SiLU(),
7778
nn.Dropout(dropout),
7879
nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
7980
)
8081
self.conv3 = nn.Sequential(
81-
nn.GroupNorm(32, out_dim),
82+
nn.GroupNorm(norm_num_groups, out_dim),
8283
nn.SiLU(),
8384
nn.Dropout(dropout),
8485
nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
8586
)
8687
self.conv4 = nn.Sequential(
87-
nn.GroupNorm(32, out_dim),
88+
nn.GroupNorm(norm_num_groups, out_dim),
8889
nn.SiLU(),
8990
nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
9091
)
9192

92-
# zero out the last layer params, so the conv block is identity
93-
nn.init.zeros_(self.conv4[-1].weight)
94-
nn.init.zeros_(self.conv4[-1].bias)
95-
9693
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
9794
identity = hidden_states
9895

@@ -169,19 +166,20 @@ def __init__(
169166
)
170167
)
171168
temp_convs.append(
172-
AllegroTemporalConvBlock(
169+
AllegroTemporalConvLayer(
173170
out_channels,
174171
out_channels,
175172
dropout=0.1,
173+
norm_num_groups=resnet_groups,
176174
)
177175
)
178176

179177
self.resnets = nn.ModuleList(resnets)
180178
self.temp_convs = nn.ModuleList(temp_convs)
181179

182180
if add_temp_downsample:
183-
self.temp_convs_down = AllegroTemporalConvBlock(
184-
out_channels, out_channels, dropout=0.1, down_sample=True, stride=3
181+
self.temp_convs_down = AllegroTemporalConvLayer(
182+
out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3
185183
)
186184
self.add_temp_downsample = add_temp_downsample
187185

@@ -258,10 +256,11 @@ def __init__(
258256
)
259257
)
260258
temp_convs.append(
261-
AllegroTemporalConvBlock(
259+
AllegroTemporalConvLayer(
262260
out_channels,
263261
out_channels,
264262
dropout=0.1,
263+
norm_num_groups=resnet_groups,
265264
)
266265
)
267266

@@ -270,8 +269,8 @@ def __init__(
270269

271270
self.add_temp_upsample = add_temp_upsample
272271
if add_temp_upsample:
273-
self.temp_conv_up = AllegroTemporalConvBlock(
274-
out_channels, out_channels, dropout=0.1, up_sample=True, stride=3
272+
self.temp_conv_up = AllegroTemporalConvLayer(
273+
out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3
275274
)
276275

277276
if self.add_upsample:
@@ -336,10 +335,11 @@ def __init__(
336335
)
337336
]
338337
temp_convs = [
339-
AllegroTemporalConvBlock(
338+
AllegroTemporalConvLayer(
340339
in_channels,
341340
in_channels,
342341
dropout=0.1,
342+
norm_num_groups=resnet_groups,
343343
)
344344
]
345345
attentions = []
@@ -383,10 +383,11 @@ def __init__(
383383
)
384384

385385
temp_convs.append(
386-
AllegroTemporalConvBlock(
386+
AllegroTemporalConvLayer(
387387
in_channels,
388388
in_channels,
389389
dropout=0.1,
390+
norm_num_groups=resnet_groups,
390391
)
391392
)
392393

@@ -513,6 +514,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
513514
sample = sample + residual
514515

515516
if self.gradient_checkpointing:
517+
516518
def create_custom_forward(module):
517519
def custom_forward(*inputs):
518520
return module(*inputs)
@@ -655,24 +657,19 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
655657
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
656658

657659
if self.gradient_checkpointing:
660+
658661
def create_custom_forward(module):
659662
def custom_forward(*inputs):
660663
return module(*inputs)
661664

662665
return custom_forward
663666

664667
# Mid block
665-
sample = torch.utils.checkpoint.checkpoint(
666-
create_custom_forward(self.mid_block),
667-
sample
668-
)
668+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
669669

670670
# Up blocks
671671
for up_block in self.up_blocks:
672-
sample = torch.utils.checkpoint.checkpoint(
673-
create_custom_forward(up_block),
674-
sample
675-
)
672+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
676673

677674
else:
678675
# Mid block

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,18 @@ def _prepare_rotary_positional_embeddings(
645645

646646
return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
647647

648+
@property
649+
def guidance_scale(self):
650+
return self._guidance_scale
651+
652+
@property
653+
def num_timesteps(self):
654+
return self._num_timesteps
655+
656+
@property
657+
def interrupt(self):
658+
return self._interrupt
659+
648660
@torch.no_grad()
649661
@replace_example_docstring(EXAMPLE_DOC_STRING)
650662
def __call__(

0 commit comments

Comments
 (0)