Skip to content

Commit 99fb2b4

Browse files
Miguel FarinhaMiguel Farinha
authored andcommitted
allow resolutions not multiple of 64 in SVD
1 parent ac61eef commit 99fb2b4

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

src/diffusers/models/unet_3d_blocks.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,6 +2231,7 @@ def forward(
22312231
hidden_states: torch.FloatTensor,
22322232
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
22332233
temb: Optional[torch.FloatTensor] = None,
2234+
upsample_size: Optional[int] = None,
22342235
image_only_indicator: Optional[torch.Tensor] = None,
22352236
) -> torch.FloatTensor:
22362237
for resnet in self.resnets:
@@ -2272,7 +2273,7 @@ def custom_forward(*inputs):
22722273

22732274
if self.upsamplers is not None:
22742275
for upsampler in self.upsamplers:
2275-
hidden_states = upsampler(hidden_states)
2276+
hidden_states = upsampler(hidden_states, upsample_size)
22762277

22772278
return hidden_states
22782279

@@ -2341,6 +2342,7 @@ def forward(
23412342
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
23422343
temb: Optional[torch.FloatTensor] = None,
23432344
encoder_hidden_states: Optional[torch.FloatTensor] = None,
2345+
upsample_size: Optional[int] = None,
23442346
image_only_indicator: Optional[torch.Tensor] = None,
23452347
) -> torch.FloatTensor:
23462348
for resnet, attn in zip(self.resnets, self.attentions):
@@ -2390,6 +2392,6 @@ def custom_forward(*inputs):
23902392

23912393
if self.upsamplers is not None:
23922394
for upsampler in self.upsamplers:
2393-
hidden_states = upsampler(hidden_states)
2395+
hidden_states = upsampler(hidden_states, upsample_size)
23942396

23952397
return hidden_states

src/diffusers/models/unet_spatio_temporal_condition.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,20 @@ def forward(
381381
If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
382382
a `tuple` is returned where the first element is the sample tensor.
383383
"""
384+
# By default samples have to be AT least a multiple of the overall upsampling factor.
385+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
386+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
387+
# on the fly if necessary.
388+
default_overall_up_factor = 2**self.num_upsamplers
389+
390+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
391+
forward_upsample_size = False
392+
upsample_size = None
393+
394+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
395+
logger.info("Forward upsample size to force interpolation output size.")
396+
forward_upsample_size = True
397+
384398
# 1. time
385399
timesteps = timestep
386400
if not torch.is_tensor(timesteps):
@@ -456,22 +470,31 @@ def forward(
456470

457471
# 5. up
458472
for i, upsample_block in enumerate(self.up_blocks):
473+
is_final_block = i == len(self.up_blocks) - 1
474+
459475
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
460476
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
461477

478+
# if we have not reached the final block and need to forward the
479+
# upsample size, we do it here
480+
if not is_final_block and forward_upsample_size:
481+
upsample_size = down_block_res_samples[-1].shape[2:]
482+
462483
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
463484
sample = upsample_block(
464485
hidden_states=sample,
465486
temb=emb,
466487
res_hidden_states_tuple=res_samples,
467488
encoder_hidden_states=encoder_hidden_states,
489+
upsample_size=upsample_size,
468490
image_only_indicator=image_only_indicator,
469491
)
470492
else:
471493
sample = upsample_block(
472494
hidden_states=sample,
473495
temb=emb,
474496
res_hidden_states_tuple=res_samples,
497+
upsample_size=upsample_size,
475498
image_only_indicator=image_only_indicator,
476499
)
477500

0 commit comments

Comments
 (0)