Skip to content

Commit c84e6d7

Browse files
committed
compute_dtype
1 parent ba23ad8 commit c84e6d7

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,11 @@ def timestep_embedding(t, dim, max_period=10000):
7070
def forward(self, t):
    """Map a timestep tensor `t` to an embedding: sinusoidal features -> MLP.

    The sinusoidal features are cast to the dtype the MLP's first layer
    expects before projection, so mixed-precision / quantized setups work.
    """
    freqs = self.timestep_embedding(t, self.frequency_embedding_size)
    first_layer = self.mlp[0]
    weight_dtype = first_layer.weight.dtype
    if weight_dtype.is_floating_point:
        # Plain float weights: match their dtype directly.
        freqs = freqs.to(weight_dtype)
    else:
        # Non-float weights (presumably quantized — TODO confirm backend):
        # honor an explicitly declared compute dtype, if any.
        fallback_dtype = getattr(first_layer, "compute_dtype", None)
        if fallback_dtype is not None:
            freqs = freqs.to(fallback_dtype)
    return self.mlp(freqs)
7780

@@ -586,7 +589,7 @@ def forward(
586589

587590
# Match t_embedder output dtype to x for layerwise casting compatibility
588591
adaln_input = t.type_as(x)
589-
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
592+
x[torch.cat(x_inner_pad_mask).to(x.device)] = self.x_pad_token.to(x.device)
590593
x = list(x.split(x_item_seqlens, dim=0))
591594
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
592595

@@ -610,7 +613,7 @@ def forward(
610613

611614
cap_feats = torch.cat(cap_feats, dim=0)
612615
cap_feats = self.cap_embedder(cap_feats)
613-
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
616+
cap_feats[torch.cat(cap_inner_pad_mask).to(cap_feats.device)] = self.cap_pad_token.to(cap_feats.device)
614617
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
615618
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
616619

0 commit comments

Comments (0)