Skip to content

Commit 3810675

Browse files
authored
Merge pull request #1176 from Feng0w0/z-image-rope
[model][NPU]: Z-image model support NPU
2 parents 00f2d1a + c1c9a48 commit 3810675

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

diffsynth/models/z_image_dit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from torch.nn import RMSNorm
1010
from ..core.attention import attention_forward
11+
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
1112
from ..core.gradient import gradient_checkpoint_forward
1213

1314

@@ -315,7 +316,10 @@ def __call__(self, ids: torch.Tensor):
315316
result = []
316317
for i in range(len(self.axes_dims)):
317318
index = ids[:, i]
318-
result.append(self.freqs_cis[i][index])
319+
if IS_NPU_AVAILABLE:
320+
result.append(torch.index_select(self.freqs_cis[i], 0, index))
321+
else:
322+
result.append(self.freqs_cis[i][index])
319323
return torch.cat(result, dim=-1)
320324

321325

0 commit comments

Comments
 (0)