
Commit e713660

refactor rope; diff: 0.14990234375; reason and fix: create rope grid on cpu and move to device
Note: the new grid creation diverges from the original behaviour: we create the grid on the device, whereas the original implementation creates it on CPU and then moves it to the device. This results in numerical differences in layerwise debugging outputs, but visually the output is the same.
1 parent d9ae8de commit e713660
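
The divergence described above is a float32 kernel difference: building the rotary grid and frequency tables directly on the GPU can give results that differ in the last bits from computing them on CPU and moving the tensor over. A minimal sketch of the effect (illustrative axis length, frequency dim, and theta, not the model's real values), assuming a CUDA device is available:

import torch

num_positions, dim, theta = 32, 16, 256.0  # made-up values for illustration

def rope_cos_table(device: torch.device) -> torch.Tensor:
    # Build the position grid and frequency table directly on `device`.
    pos = torch.arange(0, num_positions, device=device, dtype=torch.float32)
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
    return torch.outer(pos, inv_freq).cos()

if torch.cuda.is_available():
    on_gpu = rope_cos_table(torch.device("cuda"))
    # Original behaviour: compute on CPU first, then move the result to the GPU.
    moved = rope_cos_table(torch.device("cpu")).to("cuda")
    # CPU and CUDA float32 transcendental kernels can round differently, which is
    # the kind of layerwise numerical drift the commit message records.
    print((on_gpu - moved).abs().max().item())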

File tree

2 files changed: +90 -290 lines changed


src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 52 additions & 23 deletions
@@ -24,6 +24,7 @@
 from ...utils import is_torch_version
 from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor
+from ..embeddings import get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -138,26 +139,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class PatchEmbed(nn.Module):
     def __init__(
         self,
-        patch_size=16,
-        in_chans=3,
-        embed_dim=768,
-        norm_layer=None,
-        flatten=True,
-        bias=True,
-    ):
+        patch_size: Union[int, Tuple[int, int, int]] = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
         super().__init__()
 
-        patch_size = tuple(patch_size)
-        self.flatten = flatten
-        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
-        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+        patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
+        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
 
-    def forward(self, x):
-        x = self.proj(x)
-        if self.flatten:
-            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
-        x = self.norm(x)
-        return x
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj(hidden_states)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)  # BCFHW -> BNC
+        return hidden_states
 
 
 class TextProjection(nn.Module):
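
For reference, the refactored PatchEmbed is just a strided nn.Conv3d followed by a flatten and transpose. A small sketch of the shape flow on a dummy BCFHW latent, with made-up channel counts and patch sizes (the real model takes these from its config):

import torch
import torch.nn as nn

patch_size, in_chans, embed_dim = (1, 2, 2), 16, 32  # illustrative values only
proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

latents = torch.randn(1, in_chans, 4, 8, 8)    # [B, C, F, H, W]
patches = proj(latents)                        # [1, 32, 4, 4, 4]
tokens = patches.flatten(2).transpose(1, 2)    # [1, 4 * 4 * 4, 32] = BNC
print(tokens.shape)                            # torch.Size([1, 64, 32])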
@@ -384,6 +378,39 @@ def forward(
         return hidden_states
 
 
+class HunyuanVideoRotaryPosEmbed(nn.Module):
+    def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
+        super().__init__()
+
+        self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
+        self.rope_dim = rope_dim
+        self.theta = theta
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
+
+        axes_grids = []
+        for i in range(3):
+            # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
+            # original implementation creates it on CPU and then moves it to device. This results in numerical
+            # differences in layerwise debugging outputs, but visually it is the same.
+            grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
+            axes_grids.append(grid)
+        grid = torch.meshgrid(*axes_grids, indexing="ij")  # [W, H, T]
+        grid = torch.stack(grid, dim=0)  # [3, W, H, T]
+
+        freqs = []
+        for i in range(3):
+            freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+            freqs.append(freq)
+
+        freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # (W * H * T, D / 2)
+        freqs_sin = torch.cat([f[1] for f in freqs], dim=1)  # (W * H * T, D / 2)
+        return freqs_cos, freqs_sin
+
+
 class HunyuanVideoSingleTransformerBlock(nn.Module):
     def __init__(
         self,
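
A self-contained sketch of what the new HunyuanVideoRotaryPosEmbed.forward computes, using get_1d_rotary_pos_embed from diffusers.models.embeddings and made-up patch-grid sizes and per-axis dims (the real values come from the video shape and rope_dim_list): one 1D rotary table per axis, evaluated on a flattened meshgrid and concatenated along the feature dimension.

import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed

rope_sizes = [2, 4, 4]    # illustrative [T, H, W] patch-grid sizes
rope_dim = [16, 56, 56]   # illustrative per-axis frequency dims
theta = 256.0

axes_grids = [torch.arange(size, dtype=torch.float32) for size in rope_sizes]
grid = torch.stack(torch.meshgrid(*axes_grids, indexing="ij"), dim=0)  # [3, *rope_sizes]

freqs = [
    get_1d_rotary_pos_embed(rope_dim[i], grid[i].reshape(-1), theta, use_real=True)
    for i in range(3)
]
freqs_cos = torch.cat([f[0] for f in freqs], dim=1)  # one row per patch token
freqs_sin = torch.cat([f[1] for f in freqs], dim=1)
print(freqs_cos.shape, freqs_sin.shape)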
@@ -546,12 +573,12 @@ def __init__(
         guidance_embeds: bool = True,
         text_embed_dim: int = 4096,
         text_embed_dim_2: int = 768,
+        rope_theta: float = 256.0,
     ) -> None:
         super().__init__()
 
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
-        self.rope_dim_list = rope_dim_list
 
         # image projection
         self.img_in = PatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
@@ -570,6 +597,9 @@ def __init__(
         # guidance modulation
         self.guidance_in = TimestepEmbedder(inner_dim, nn.SiLU)
 
+        # 3. RoPE
+        self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_dim_list, rope_theta)
+
         self.transformer_blocks = nn.ModuleList(
             [
                 HunyuanVideoTransformerBlock(
@@ -664,8 +694,6 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
         encoder_hidden_states_2: torch.Tensor,
-        freqs_cos: Optional[torch.Tensor] = None,
-        freqs_sin: Optional[torch.Tensor] = None,
         guidance: torch.Tensor = None,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
@@ -676,6 +704,8 @@ def forward(
         post_patch_height = height // p
         post_patch_width = width // p
 
+        image_rotary_emb = self.rope(hidden_states)
+
         temb = self.time_in(timestep)
         temb = temb + self.vector_in(encoder_hidden_states_2)
         temb = temb + self.guidance_in(guidance)
@@ -691,15 +721,14 @@ def forward(
            else lambda x: x
        )
 
-        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
        for _, block in enumerate(self.transformer_blocks):
            hidden_states, encoder_hidden_states = block_forward(block)(
-                hidden_states, encoder_hidden_states, temb, freqs_cis
+                hidden_states, encoder_hidden_states, temb, image_rotary_emb
            )
 
        for block in self.single_transformer_blocks:
            hidden_states, encoder_hidden_states = block_forward(block)(
-                hidden_states, encoder_hidden_states, temb, freqs_cis
+                hidden_states, encoder_hidden_states, temb, image_rotary_emb
            )
 
        hidden_states = self.norm_out(hidden_states, temb)
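
The forward-pass change is purely wiring: freqs_cos/freqs_sin disappear from the signature because the rotary embedding is now owned by the model and derived from the input itself. A toy illustration of that pattern (not the real model or its signature):

import torch
import torch.nn as nn

class ToyRope(nn.Module):
    def forward(self, hidden_states: torch.Tensor):
        # Derive positions from the input instead of receiving precomputed tables.
        seq_len, dim = hidden_states.shape[1], 8
        pos = torch.arange(seq_len, device=hidden_states.device, dtype=torch.float32)
        inv_freq = 1.0 / (256.0 ** (torch.arange(0, dim, 2, device=hidden_states.device, dtype=torch.float32) / dim))
        freqs = torch.outer(pos, inv_freq)
        return freqs.cos(), freqs.sin()

class ToyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.rope = ToyRope()  # before the refactor, the caller passed freqs_cos / freqs_sin into forward

    def forward(self, hidden_states: torch.Tensor):
        image_rotary_emb = self.rope(hidden_states)  # computed internally, threaded to every block
        return image_rotary_emb

cos, sin = ToyTransformer()(torch.randn(1, 6, 8))
print(cos.shape, sin.shape)  # torch.Size([6, 4]) each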
