We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8f1d10f commit 3ee5f53Copy full SHA for 3ee5f53
diffsynth/models/z_image_dit.py
@@ -8,6 +8,7 @@
8
9
from torch.nn import RMSNorm
10
from ..core.attention import attention_forward
11
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
12
from ..core.gradient import gradient_checkpoint_forward
13
14
@@ -274,7 +275,10 @@ def __call__(self, ids: torch.Tensor):
274
275
result = []
276
for i in range(len(self.axes_dims)):
277
index = ids[:, i]
- result.append(self.freqs_cis[i][index])
278
+ if IS_NPU_AVAILABLE:
279
+ result.append(self.freqs_cis[i][index])
280
+ else:
281
+ result.append(torch.index_select(self.freqs_cis[i], 0, index))
282
return torch.cat(result, dim=-1)
283
284
0 commit comments