@@ -70,8 +70,11 @@ def timestep_embedding(t, dim, max_period=10000):
     def forward(self, t):
         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
         weight_dtype = self.mlp[0].weight.dtype
+        compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
         if weight_dtype.is_floating_point:
             t_freq = t_freq.to(weight_dtype)
+        elif compute_dtype is not None:
+            t_freq = t_freq.to(compute_dtype)
         t_emb = self.mlp(t_freq)
         return t_emb
 
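Note on the hunk above (a sketch, not part of the patch): when the first layer of `self.mlp` stores quantized weights, `weight.dtype` is an integer dtype and `is_floating_point` is False, so the old code left `t_freq` uncast. Quantized linear layers often expose a float `compute_dtype` attribute, which the new `elif` branch falls back to. A minimal, hypothetical illustration of that selection logic (the `FakeQuantLinear` stand-in and `cast_like_forward` helper are assumptions, not code from this repo):

```python
import torch

class FakeQuantLinear:
    """Stand-in for a quantized linear: integer weight storage, float16 compute."""
    weight = torch.zeros(4, 4, dtype=torch.int8)
    compute_dtype = torch.float16

def cast_like_forward(layer, t_freq):
    weight_dtype = layer.weight.dtype
    compute_dtype = getattr(layer, "compute_dtype", None)
    if weight_dtype.is_floating_point:
        return t_freq.to(weight_dtype)    # regular fp32/fp16/bf16 weights
    elif compute_dtype is not None:
        return t_freq.to(compute_dtype)   # quantized weights: use the compute dtype
    return t_freq                         # no hint available: leave unchanged

print(cast_like_forward(FakeQuantLinear(), torch.randn(2, 4)).dtype)  # torch.float16
```

For an ordinary floating-point linear the first branch is taken and behaviour is unchanged.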
@@ -586,7 +589,7 @@ def forward(
 
         # Match t_embedder output dtype to x for layerwise casting compatibility
         adaln_input = t.type_as(x)
-        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+        x[torch.cat(x_inner_pad_mask).to(x.device)] = self.x_pad_token.to(x.device)
         x = list(x.split(x_item_seqlens, dim=0))
         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
 
@@ -610,7 +613,7 @@ def forward(
 
         cap_feats = torch.cat(cap_feats, dim=0)
         cap_feats = self.cap_embedder(cap_feats)
-        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
+        cap_feats[torch.cat(cap_inner_pad_mask).to(cap_feats.device)] = self.cap_pad_token.to(cap_feats.device)
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
         cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
 
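Note on the last two hunks (an illustrative sketch, not the model's real code): the per-item pad masks are concatenated and may live on a different device than the packed features, and the learned pad token may have been placed or offloaded elsewhere, so both are moved to the indexed tensor's device before the masked assignment. The tensor names below are placeholders:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(6, 8, device=device)                    # packed token features
inner_pad_mask = [torch.tensor([False, False, True]),   # per-item pad masks,
                  torch.tensor([False, True, True])]    # built on CPU
pad_token = torch.zeros(8)                              # learned pad embedding (CPU here)

# Moving the concatenated mask and the pad value to x.device keeps the masked
# assignment device-agnostic, regardless of where x ended up; on CPU-only runs
# the .to() calls are no-ops.
x[torch.cat(inner_pad_mask).to(x.device)] = pad_token.to(x.device)
```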