diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 5972505f2897..d05af686dede 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -97,6 +97,7 @@ def __init__( out_channels: int = 3, center_input_sample: bool = False, time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, freq_shift: int = 0, flip_sin_to_cos: bool = True, down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), @@ -122,7 +123,7 @@ def __init__( super().__init__() self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 # Check inputs if len(down_block_types) != len(up_block_types):