Skip to content

Commit da06a2c

Browse files
committed
-device cast
1 parent c84e6d7 commit da06a2c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def forward(
589589

590590
# Match t_embedder output dtype to x for layerwise casting compatibility
591591
adaln_input = t.type_as(x)
592-
x[torch.cat(x_inner_pad_mask).to(x.device)] = self.x_pad_token.to(x.device)
592+
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
593593
x = list(x.split(x_item_seqlens, dim=0))
594594
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
595595

@@ -613,7 +613,7 @@ def forward(
613613

614614
cap_feats = torch.cat(cap_feats, dim=0)
615615
cap_feats = self.cap_embedder(cap_feats)
616-
cap_feats[torch.cat(cap_inner_pad_mask).to(cap_feats.device)] = self.cap_pad_token.to(cap_feats.device)
616+
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
617617
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
618618
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
619619

0 commit comments

Comments
 (0)