Skip to content

Commit 21d8130

Browse files
committed
fixed stack issue
1 parent aae03cf commit 21d8130

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,11 @@ def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
304304
def __call__(self, ids: torch.Tensor):
305305
assert ids.ndim == 2
306306
assert ids.shape[-1] == len(self.axes_dims)
307+
device = ids.device
307308

308309
if self.freqs_cis is None:
309310
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
310-
self.freqs_cis = [freqs_cis.cuda() for freqs_cis in self.freqs_cis]
311+
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
311312

312313
result = []
313314
for i in range(len(self.axes_dims)):
@@ -596,6 +597,7 @@ def forward(
596597
x_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
597598
x_attn_mask[i, seq_len:] = 0
598599
x = torch.stack(x)
600+
x_freqs_cis = torch.stack(x_freqs_cis)
599601

600602
for layer in self.noise_refiner:
601603
x = layer(
@@ -638,6 +640,7 @@ def forward(
638640
cap_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
639641
cap_attn_mask[i, seq_len:] = 0
640642
cap_feats = torch.stack(cap_feats)
643+
cap_freqs_cis = torch.stack(cap_freqs_cis)
641644

642645
for layer in self.context_refiner:
643646
cap_feats = layer(
@@ -680,6 +683,7 @@ def forward(
680683
unified_freqs_cis[i] = torch.cat([freqs_item, freqs_pad_tensor.repeat(pad_len, 1)])
681684
unified_attn_mask[i, seq_len:] = 0
682685
unified = torch.stack(unified)
686+
unified_freqs_cis = torch.stack(unified_freqs_cis)
683687

684688
for layer in self.layers:
685689
unified = layer(

0 commit comments

Comments
 (0)