Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@
def prepare_causal_attention_mask(
num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
) -> torch.Tensor:
seq_len = num_frames * height_width
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // height_width
mask[i, : (i_frame + 1) * height_width] = 0
indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
indices_blocks = indices.repeat_interleave(height_width)
x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)

if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask
Expand Down
26 changes: 26 additions & 0 deletions tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch

from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
Expand Down Expand Up @@ -182,3 +183,28 @@ def test_forward_with_norm_groups(self):
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass

def test_prepare_causal_attention_mask(self):
def prepare_causal_attention_mask_orig(
num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
) -> torch.Tensor:
seq_len = num_frames * height_width
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // height_width
mask[i, : (i_frame + 1) * height_width] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask

# test with some odd shapes
original_mask = prepare_causal_attention_mask_orig(
num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
)
new_mask = prepare_causal_attention_mask(
num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
)
self.assertTrue(
torch.allclose(original_mask, new_mask),
"Causal attention mask should be the same",
)