diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index f7c73231725d..48b731406191 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -326,6 +326,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig Whether to use elementwise affinity in the normalization layer. norm_eps (`float`, defaults to `1e-6`): The epsilon value for the normalization layer. + qk_norm (`str`, *optional*, defaults to `None`): + The normalization to use for the query and key. + timestep_scale (`float`, defaults to `1.0`): + The scale to use for the timesteps. """ _supports_gradient_checkpointing = True @@ -355,6 +359,7 @@ def __init__( guidance_embeds: bool = False, guidance_embeds_scale: float = 0.1, qk_norm: Optional[str] = None, + timestep_scale: float = 1.0, ) -> None: super().__init__() diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 76934d055c56..6093fd836aad 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -938,6 +938,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + timestep = timestep * self.transformer.config.timestep_scale # predict noise model_output noise_pred = self.transformer(