Commit 9039db4

use diffusers timesteps embedding; diff: 0.10205078125
1 parent e713660 commit 9039db4

File tree

1 file changed: +16 -13 lines changed


src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 16 additions & 13 deletions
@@ -24,7 +24,7 @@
 from ...utils import is_torch_version
 from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor
-from ..embeddings import get_1d_rotary_pos_embed
+from ..embeddings import get_1d_rotary_pos_embed, get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -219,7 +219,8 @@ def __init__(
         )

     def forward(self, t):
-        t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
+        # t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
+        t_freq = get_timestep_embedding(t, self.frequency_embedding_size, flip_sin_to_cos=True, max_period=self.max_period, downscale_freq_shift=0).type(self.mlp[0].weight.dtype)
         t_emb = self.mlp(t_freq)
         return t_emb
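For context on this hunk: `get_timestep_embedding` lives in `diffusers.models.embeddings`, and `flip_sin_to_cos=True` together with `downscale_freq_shift=0` selects the cos-before-sin layout and the `arange(half) / half` frequency schedule used by DiT/ADM-style `timestep_embedding` helpers. The sketch below is a hedged comparison under the assumption that the replaced `timestep_embedding` followed that reference convention; the 0.10205078125 figure in the commit message is the end-to-end diff reported by the author, not something this snippet reproduces.

```python
import math

import torch
from diffusers.models.embeddings import get_timestep_embedding


def dit_style_timestep_embedding(t, dim, max_period=10000):
    # Assumed reference convention: freqs = exp(-log(max_period) * i / half), cos half first.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


t = torch.arange(0, 1000, dtype=torch.float32)
ref = dit_style_timestep_embedding(t, 256)
new = get_timestep_embedding(t, 256, flip_sin_to_cos=True, downscale_freq_shift=0, max_period=10000)
print((ref - new).abs().max())  # should be near zero in fp32 if the assumed convention holds
```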

@@ -231,24 +232,22 @@ def __init__(
         attention_head_dim: int,
         mlp_width_ratio: str = 4.0,
         mlp_drop_rate: float = 0.0,
-        qkv_bias: bool = True,
+        attention_bias: bool = True,
     ) -> None:
         super().__init__()

         hidden_size = num_attention_heads * attention_head_dim

         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
-
         self.attn = Attention(
             query_dim=hidden_size,
             cross_attention_dim=None,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
-            bias=True,
+            bias=attention_bias,
         )

         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
-
         self.mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="silu", dropout=mlp_drop_rate)

         self.adaLN_modulation = nn.Sequential(
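A note on the `qkv_bias` to `attention_bias` rename in this hunk: in diffusers, the `bias` argument of `Attention` adds a bias term to the query/key/value projection layers, which is what the old name referred to, and the change also wires the flag through instead of the previously hard-coded `bias=True`. A minimal, illustrative construction (the head counts below are placeholder values, not taken from the file):

```python
from diffusers.models.attention_processor import Attention

# Illustrative values only; the block derives hidden_size from
# num_attention_heads * attention_head_dim as shown in the hunk above.
num_attention_heads, attention_head_dim = 24, 128
attention_bias = True

attn = Attention(
    query_dim=num_attention_heads * attention_head_dim,
    cross_attention_dim=None,    # self-attention
    heads=num_attention_heads,
    dim_head=attention_head_dim,
    bias=attention_bias,         # adds a bias to the Q/K/V projections (formerly "qkv_bias")
)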
@@ -286,8 +285,8 @@ def __init__(
         num_layers: int,
         mlp_width_ratio: float = 4.0,
         mlp_drop_rate: float = 0.0,
-        qkv_bias: bool = True,
-    ):
+        attention_bias: bool = True,
+    ) -> None:
         super().__init__()

         self.refiner_blocks = nn.ModuleList(
@@ -297,7 +296,7 @@ def __init__(
                 attention_head_dim=attention_head_dim,
                 mlp_width_ratio=mlp_width_ratio,
                 mlp_drop_rate=mlp_drop_rate,
-                qkv_bias=qkv_bias,
+                attention_bias=attention_bias,
             )
             for _ in range(num_layers)
         ]
@@ -308,7 +307,7 @@ def forward(
         hidden_states: torch.Tensor,
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-    ):
+    ) -> None:
         self_attn_mask = None
         if attention_mask is not None:
             batch_size = attention_mask.shape[0]
@@ -334,13 +333,15 @@ def __init__(
         num_layers: int,
         mlp_ratio: float = 4.0,
         mlp_drop_rate: float = 0.0,
-        qkv_bias: bool = True,
-    ):
+        attention_bias: bool = True,
+    ) -> None:
         super().__init__()

         hidden_size = num_attention_heads * attention_head_dim

         self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True)
+        # self.time_embed = TimestepEmbedder(hidden_size, nn.SiLU)
+        # self.context_embed = TextProjection(in_channels, hidden_size, nn.SiLU)
         self.t_embedder = TimestepEmbedder(hidden_size, nn.SiLU)
         self.c_embedder = TextProjection(in_channels, hidden_size, nn.SiLU)

@@ -350,7 +351,7 @@ def __init__(
             num_layers=num_layers,
             mlp_width_ratio=mlp_ratio,
             mlp_drop_rate=mlp_drop_rate,
-            qkv_bias=qkv_bias,
+            attention_bias=attention_bias,
         )

     def forward(
@@ -360,6 +361,7 @@ def forward(
         attention_mask: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         original_dtype = hidden_states.dtype
+        # temb = self.time_embed(timestep)
         temb = self.t_embedder(timestep)

         if attention_mask is None:
@@ -369,6 +371,7 @@ def forward(
             pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
             pooled_projections = pooled_projections.to(original_dtype)

+            # pooled_projections = self.context_embed(pooled_projections)
             pooled_projections = self.c_embedder(pooled_projections)
             emb = temb + pooled_projections
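For context, the `pooled_projections` lines in this hunk compute a masked mean over valid (non-padded) tokens before the `c_embedder` projection. A small standalone sketch of that pooling pattern, with illustrative shapes and names:

```python
import torch

# Masked mean pooling: average token features over valid positions only.
batch, seq_len, dim = 2, 5, 4
hidden_states = torch.randn(batch, seq_len, dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding

mask_float = attention_mask.float().unsqueeze(-1)                          # (batch, seq_len, 1)
pooled = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)   # (batch, dim)
print(pooled.shape)
```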
