diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py
index 358759164b9e..49a45c7c0525 100644
--- a/src/diffusers/models/transformers/transformer_skyreels_v2.py
+++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py
@@ -40,6 +40,62 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class SkyReelsV2AdaLayerNorm(nn.Module):
+    r"""
+    Norm layer modified to incorporate timestep embeddings.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        output_dim (`int`, *optional*): The size of the projected output. Defaults to `embedding_dim * 2`.
+        norm_elementwise_affine (`bool`, defaults to `False`): Whether the norm has learnable affine parameters.
+        norm_eps (`float`, defaults to `1e-5`): The epsilon value used by the layer norm.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        output_dim: Optional[int] = None,
+        norm_elementwise_affine: bool = False,
+        norm_eps: float = 1e-5,
+    ):
+        super().__init__()
+
+        output_dim = output_dim or embedding_dim * 2
+
+        self.linear = nn.Linear(embedding_dim, output_dim)
+        self.linear.weight.data[:embedding_dim, :] = torch.eye(embedding_dim)
+        self.linear.weight.data[embedding_dim:, :] = torch.eye(embedding_dim)
+        self.norm = FP32LayerNorm(embedding_dim, norm_eps, norm_elementwise_affine)
+
+    def forward(self, x: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
+        if temb.ndim == 2:
+            # If temb is 2D, we assume it has 1-D time embedding values for each batch.
+            # For models:
+            # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
+            # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
+            # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
+            # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
+            # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
+            # 2D temb: (batch, embedding_dim)
+            temb = self.linear(temb.unsqueeze(1))  # (batch, 1, embedding_dim * 2)
+            shift, scale = temb.chunk(2, dim=2)
+        elif temb.ndim == 3:
+            # If temb is 3D, we assume it has 2-D time embedding values for each batch.
+            # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
+            # For models:
+            # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
+            # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
+            # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
+            # 3D temb: (batch, num_latent_frames * post_patch_height * post_patch_width, embedding_dim)
+            temb = self.linear(
+                temb
+            )  # (batch, num_latent_frames * post_patch_height * post_patch_width, embedding_dim * 2)
+            shift, scale = temb.chunk(2, dim=2)
+
+        x = self.norm(x) * (1 + scale) + shift
+        return x
+
+
 def _get_qkv_projections(
     attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
 ):
@@ -559,9 +615,10 @@ class SkyReelsV2Transformer3DModel(
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
-    _no_split_modules = ["SkyReelsV2TransformerBlock"]
+    _no_split_modules = ["SkyReelsV2TransformerBlock", "SkyReelsV2AdaLayerNorm"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
-    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+    _keys_to_ignore_on_load_unexpected = ["norm_added_q", "scale_shift_table"]
+    _keys_to_ignore_on_load_missing = ["norm_out.linear.weight", "norm_out.linear.bias"]
     _repeated_blocks = ["SkyReelsV2TransformerBlock"]
 
     @register_to_config
@@ -617,9 +674,13 @@ def __init__(
         )
 
         # 4. Output norm & projection
-        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+        self.norm_out = SkyReelsV2AdaLayerNorm(
+            embedding_dim=inner_dim,
+            output_dim=2 * inner_dim,
+            norm_elementwise_affine=False,
+            norm_eps=eps,
+        )
         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
-        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
 
         if inject_sample_info:
             self.fps_embedding = nn.Embedding(2, inner_dim)
@@ -732,34 +793,7 @@ def forward(
                     causal_mask,
                 )
 
-        if temb.dim() == 2:
-            # If temb is 2D, we assume it has time 1-D time embedding values for each batch.
-            # For models:
-            # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
-            # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
-            # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
-            # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
-            # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
-            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
-        elif temb.dim() == 3:
-            # If temb is 3D, we assume it has 2-D time embedding values for each batch.
-            # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
-            # For models:
-            # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
-            # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
-            # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
-            shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
-            shift, scale = shift.squeeze(1), scale.squeeze(1)
-
-        # Move the shift and scale tensors to the same device as hidden_states.
-        # When using multi-GPU inference via accelerate these will be on the
-        # first device rather than the last device, which hidden_states ends up
-        # on.
-        shift = shift.to(hidden_states.device)
-        scale = scale.to(hidden_states.device)
-
-        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
-
+        hidden_states = self.norm_out(hidden_states, temb=temb)
         hidden_states = self.proj_out(hidden_states)
         hidden_states = hidden_states.reshape(
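
Reviewer note on the identity initialization above: a minimal, self-contained sketch (plain PyTorch, not the diffusers module) of why folding the removed `scale_shift_table` into `norm_out.linear` preserves the old output path. The bias assignment below is an assumption about how an old checkpoint's `scale_shift_table` maps onto the new `norm_out.linear.bias`; the conversion itself is not part of this diff.

```python
import torch
import torch.nn as nn

dim, batch, seq = 16, 2, 5
x = torch.randn(batch, seq, dim)
temb = torch.randn(batch, dim)  # 2D temb case (T2V/I2V models)
norm = nn.LayerNorm(dim, eps=1e-5, elementwise_affine=False)

# Old path: a learned (1, 2, dim) table added to the time embedding.
scale_shift_table = torch.randn(1, 2, dim) / dim**0.5
shift_old, scale_old = (scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
out_old = norm(x) * (1 + scale_old) + shift_old

# New path: Linear(dim, 2*dim) whose two weight blocks are identity, so
# linear(temb) == [temb + bias[:dim], temb + bias[dim:]].
linear = nn.Linear(dim, 2 * dim)
linear.weight.data[:dim, :] = torch.eye(dim)
linear.weight.data[dim:, :] = torch.eye(dim)
# Assumed checkpoint mapping: the old table rows become the new bias.
linear.bias.data = scale_shift_table.flatten()

shift_new, scale_new = linear(temb.unsqueeze(1)).chunk(2, dim=2)
out_new = norm(x) * (1 + scale_new) + shift_new

print(torch.allclose(out_old, out_new, atol=1e-6))  # True
```

The same equivalence holds per-token for the 3D Diffusion Forcing `temb`, since the linear applies position-wise. And because shift/scale now live inside a submodule listed in `_no_split_modules`, accelerate places them together with `norm_out`, which is why the manual `.to(hidden_states.device)` moves could be dropped.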
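Relatedly, the two `_keys_to_ignore_on_load_*` additions line up with this mapping: old checkpoints carry a now-unexpected `scale_shift_table` and lack the new `norm_out.linear.*` keys. A hedged sketch of what a one-off state-dict migration could look like, using the key names from the diff (the helper itself is hypothetical, not code from this PR):

```python
import torch

def migrate_norm_out(state_dict: dict, inner_dim: int) -> dict:
    """Hypothetical: fold an old `scale_shift_table` into the new norm_out keys."""
    sd = dict(state_dict)
    if "scale_shift_table" in sd:
        table = sd.pop("scale_shift_table")  # shape (1, 2, inner_dim)
        # Two stacked identity blocks reproduce `temb + table` for shift and scale.
        sd["norm_out.linear.weight"] = torch.eye(inner_dim).repeat(2, 1)
        sd["norm_out.linear.bias"] = table.flatten()
    return sd
```

Note that loading an unconverted checkpoint as-is leaves `norm_out.linear.bias` at its default `nn.Linear` initialization (only the weight is set to identity in `__init__`), so the actual table values have to arrive via a converted checkpoint along the lines of the migration above.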