Skip to content
Merged
2 changes: 1 addition & 1 deletion src/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class conditioning with `class_embed_type` equal to `None`.
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,12 @@ def __init__(
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int)
is_unet_sample_size_less_64 = (
hasattr(unet.config, "sample_size")
and self._is_unet_config_sample_size_int
and unet.config.sample_size < 64
)
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
Expand Down Expand Up @@ -902,8 +907,18 @@ def __call__(
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if not height or not width:
height = (
self.unet.config.sample_size
if self._is_unet_config_sample_size_int
else self.unet.config.sample_size[0]
)
width = (
self.unet.config.sample_size
if self._is_unet_config_sample_size_int
else self.unet.config.sample_size[1]
)
height, width = height * self.vae_scale_factor, width * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks

# 1. Check inputs. Raise error if not correct
Expand Down
8 changes: 8 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,14 @@ def callback_on_step_end(pipe, i, t, callback_kwargs):
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)

def test_pipeline_accept_tuple_type_unet_sample_size(self):
# the purpose of this test is to see whether the pipeline would accept a unet with the tuple-typed sample size
sd_repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
sample_size = [60, 80]
customised_unet = UNet2DConditionModel(sample_size=sample_size)
pipe = StableDiffusionPipeline.from_pretrained(sd_repo_id, unet=customised_unet)
assert pipe.unet.config.sample_size == sample_size
Comment on lines +843 to +849
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use a smaller checkpoint?
https://huggingface.co/hf-internal-testing/tiny-sd-pipe



@slow
@require_torch_gpu
Expand Down
Loading