diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
index 22b833734f0f..089e641d8852 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
@@ -36,11 +36,11 @@
 def prepare_causal_attention_mask(
     num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
 ) -> torch.Tensor:
-    seq_len = num_frames * height_width
-    mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
-    for i in range(seq_len):
-        i_frame = i // height_width
-        mask[i, : (i_frame + 1) * height_width] = 0
+    indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
+    indices_blocks = indices.repeat_interleave(height_width)
+    x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
+    mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
+
     if batch_size is not None:
         mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
     return mask
diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
index 7b7901a6fd94..00d4b8ed2b5f 100644
--- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
+++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
@@ -18,6 +18,7 @@
 import torch
 
 from diffusers import AutoencoderKLHunyuanVideo
+from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
@@ -182,3 +183,28 @@ def test_forward_with_norm_groups(self):
     @unittest.skip("Unsupported test.")
     def test_outputs_equivalence(self):
         pass
+
+    def test_prepare_causal_attention_mask(self):
+        def prepare_causal_attention_mask_orig(
+            num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
+        ) -> torch.Tensor:
+            seq_len = num_frames * height_width
+            mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
+            for i in range(seq_len):
+                i_frame = i // height_width
+                mask[i, : (i_frame + 1) * height_width] = 0
+            if batch_size is not None:
+                mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
+            return mask
+
+        # test with some odd shapes
+        original_mask = prepare_causal_attention_mask_orig(
+            num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
+        )
+        new_mask = prepare_causal_attention_mask(
+            num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
+        )
+        self.assertTrue(
+            torch.allclose(original_mask, new_mask),
+            "Causal attention mask should be the same",
+        )
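
Note (illustration only, not part of the diff): the rewritten function builds the block-causal mask in one vectorized step instead of filling it row by row. A minimal standalone sketch on a tiny, arbitrarily chosen input (3 frames, 2 spatial tokens per frame) shows the structure it produces:

# Minimal sketch (not part of the diff): the vectorized block-causal mask
# construction from prepare_causal_attention_mask, shown on a tiny input.
import torch

num_frames, height_width = 3, 2  # arbitrary demo sizes: 3 frames, 2 tokens per frame

# 1-based frame index for each of the num_frames * height_width tokens: [1, 1, 2, 2, 3, 3]
indices = torch.arange(1, num_frames + 1, dtype=torch.int32)
indices_blocks = indices.repeat_interleave(height_width)

# With indexing="xy", x[i, j] is the frame of column token j and y[i, j] is the frame
# of row token i, so row i attends to column j exactly when frame(j) <= frame(i).
x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
mask = torch.where(x <= y, 0, -float("inf")).to(dtype=torch.float32)

print(mask)
# row 0 (frame 1): [0, 0, -inf, -inf, -inf, -inf]
# row 2 (frame 2): [0, 0, 0,    0,    -inf, -inf]
# row 5 (frame 3): [0, 0, 0,    0,    0,    0   ]
# i.e. the same rows the replaced loop produced via mask[i, : (i_frame + 1) * height_width] = 0.

This avoids a Python loop over all num_frames * height_width rows, and the new test above checks the vectorized result against the original loop-based implementation.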