6 changes: 5 additions & 1 deletion diffsynth/models/z_image_dit.py
@@ -8,6 +8,7 @@

 from torch.nn import RMSNorm
 from ..core.attention import attention_forward
+from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE
 from ..core.gradient import gradient_checkpoint_forward

@@ -274,7 +275,10 @@ def __call__(self, ids: torch.Tensor):
         result = []
         for i in range(len(self.axes_dims)):
             index = ids[:, i]
-            result.append(self.freqs_cis[i][index])
+            if IS_NPU_AVAILABLE:
+                result.append(torch.index_select(self.freqs_cis[i], 0, index))
+            else:
+                result.append(self.freqs_cis[i][index])
Comment on lines +278 to +281

Contributor review comment (severity: high):
Using the global IS_NPU_AVAILABLE flag can lead to incorrect behavior if the system has an NPU but the model is running on a different device (such as CUDA or CPU). It is more robust to check the device of the tensor itself to decide which code path to take.

By checking ids.device.type, you ensure that the NPU-specific code path is only taken when the model is actually running on an NPU device. This change would also make the import of IS_NPU_AVAILABLE at the top of the file unnecessary.

Suggested change:

-            if IS_NPU_AVAILABLE:
+            if ids.device.type == 'npu':
                 result.append(torch.index_select(self.freqs_cis[i], 0, index))
             else:
                 result.append(self.freqs_cis[i][index])

return torch.cat(result, dim=-1)
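The reviewer's suggestion can be sketched as a standalone helper. This is a minimal illustration, not the PR's actual code: `rope_lookup` is a hypothetical function name, and it dispatches on the device type of the index tensor rather than on a process-wide flag, exactly as the comment proposes. On devices where advanced indexing is well supported, `torch.index_select(t, 0, idx)` and `t[idx]` return the same result for a 1-D index tensor.

```python
import torch

def rope_lookup(freqs_cis: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of device-based dispatch for a frequency-table lookup.

    Dispatching on index.device.type (rather than a global availability flag)
    means the NPU-specific path only runs when the tensors actually live on
    an NPU, even if an NPU is present elsewhere in the system.
    """
    if index.device.type == "npu":
        # index_select avoids the advanced-indexing kernel, which may be
        # unsupported or slow on some NPU backends.
        return torch.index_select(freqs_cis, 0, index)
    # Default path: plain advanced indexing.
    return freqs_cis[index]
```

On CPU (and CUDA) the two lookups are equivalent, so the dispatch changes only which kernel runs, not the result.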

