We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 00f2d1a + c1c9a48 commit 3810675Copy full SHA for 3810675
diffsynth/models/z_image_dit.py
@@ -8,6 +8,7 @@
8
9
from torch.nn import RMSNorm
10
from ..core.attention import attention_forward
11
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
12
from ..core.gradient import gradient_checkpoint_forward
13
14
@@ -315,7 +316,10 @@ def __call__(self, ids: torch.Tensor):
315
316
result = []
317
for i in range(len(self.axes_dims)):
318
index = ids[:, i]
- result.append(self.freqs_cis[i][index])
319
+ if IS_NPU_AVAILABLE:
320
+ result.append(torch.index_select(self.freqs_cis[i], 0, index))
321
+ else:
322
+ result.append(self.freqs_cis[i][index])
323
return torch.cat(result, dim=-1)
324
325
0 commit comments