Skip to content

Commit e34fb67

Browse files
mlfarinha, Miguel Farinha, hlky
authored and committed
Allow image resolutions multiple of 8 instead of 64 in SVD pipeline (#6646)
Allow resolutions that are not a multiple of 64 in SVD. Co-authored-by: Miguel Farinha <[email protected]>; Co-authored-by: hlky <[email protected]>
1 parent e99a249 commit e34fb67

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

src/diffusers/models/unets/unet_3d_blocks.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,7 @@ def forward(
13751375
res_hidden_states_tuple: Tuple[torch.Tensor, ...],
13761376
temb: Optional[torch.Tensor] = None,
13771377
image_only_indicator: Optional[torch.Tensor] = None,
1378+
upsample_size: Optional[int] = None,
13781379
) -> torch.Tensor:
13791380
for resnet in self.resnets:
13801381
# pop res hidden states
@@ -1415,7 +1416,7 @@ def custom_forward(*inputs):
14151416

14161417
if self.upsamplers is not None:
14171418
for upsampler in self.upsamplers:
1418-
hidden_states = upsampler(hidden_states)
1419+
hidden_states = upsampler(hidden_states, upsample_size)
14191420

14201421
return hidden_states
14211422

@@ -1485,6 +1486,7 @@ def forward(
14851486
temb: Optional[torch.Tensor] = None,
14861487
encoder_hidden_states: Optional[torch.Tensor] = None,
14871488
image_only_indicator: Optional[torch.Tensor] = None,
1489+
upsample_size: Optional[int] = None,
14881490
) -> torch.Tensor:
14891491
for resnet, attn in zip(self.resnets, self.attentions):
14901492
# pop res hidden states
@@ -1533,6 +1535,6 @@ def custom_forward(*inputs):
15331535

15341536
if self.upsamplers is not None:
15351537
for upsampler in self.upsamplers:
1536-
hidden_states = upsampler(hidden_states)
1538+
hidden_states = upsampler(hidden_states, upsample_size)
15371539

15381540
return hidden_states

src/diffusers/models/unets/unet_spatio_temporal_condition.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,20 @@ def forward(
382382
If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is
383383
returned, otherwise a `tuple` is returned where the first element is the sample tensor.
384384
"""
385+
# By default, samples have to be at least a multiple of the overall upsampling factor.
386+
# The overall upsampling factor is equal to 2 ** (number of upsampling layers).
387+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
388+
# on the fly if necessary.
389+
default_overall_up_factor = 2**self.num_upsamplers
390+
391+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
392+
forward_upsample_size = False
393+
upsample_size = None
394+
395+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
396+
logger.info("Forward upsample size to force interpolation output size.")
397+
forward_upsample_size = True
398+
385399
# 1. time
386400
timesteps = timestep
387401
if not torch.is_tensor(timesteps):
@@ -457,22 +471,31 @@ def forward(
457471

458472
# 5. up
459473
for i, upsample_block in enumerate(self.up_blocks):
474+
is_final_block = i == len(self.up_blocks) - 1
475+
460476
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
461477
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
462478

479+
# if we have not reached the final block and need to forward the
480+
# upsample size, we do it here
481+
if not is_final_block and forward_upsample_size:
482+
upsample_size = down_block_res_samples[-1].shape[2:]
483+
463484
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
464485
sample = upsample_block(
465486
hidden_states=sample,
466487
temb=emb,
467488
res_hidden_states_tuple=res_samples,
468489
encoder_hidden_states=encoder_hidden_states,
490+
upsample_size=upsample_size,
469491
image_only_indicator=image_only_indicator,
470492
)
471493
else:
472494
sample = upsample_block(
473495
hidden_states=sample,
474496
temb=emb,
475497
res_hidden_states_tuple=res_samples,
498+
upsample_size=upsample_size,
476499
image_only_indicator=image_only_indicator,
477500
)
478501

0 commit comments

Comments
 (0)