
Commit b9ec619

白超 committed
bugfix for NPU not supporting float64
1 parent 8421c14 commit b9ec619

21 files changed (+62, -62 lines)
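Every hunk in this commit applies the same fix: the Ascend NPU backend, like Apple's MPS backend, does not support float64 in this code path, so the device check that previously only covered "mps" is extended to also cover "npu", and the timestep tensor falls back to 32-bit dtypes on those devices. A minimal sketch of the pattern, with a hypothetical helper name (in diffusers the logic is inlined in each forward()/get_time_embed(), not factored out like this):

import torch

def timestep_to_tensor(timestep, sample: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper illustrating the commit's pattern: MPS and NPU lack
    # 64-bit float support here, so use float32/int32 on those devices instead.
    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
    if isinstance(timestep, float):
        dtype = torch.float32 if is_mps_or_npu else torch.float64
    else:
        dtype = torch.int32 if is_mps_or_npu else torch.int64
    return torch.tensor([timestep], dtype=dtype, device=sample.device)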

examples/community/fresco_v2v.py

Lines changed: 3 additions & 3 deletions
@@ -403,11 +403,11 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
+            is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if is_mps_or_npu else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if is_mps_or_npu else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

examples/community/matryoshka.py

Lines changed: 3 additions & 3 deletions
@@ -2805,11 +2805,11 @@ def get_time_embed(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
+            is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if is_mps_or_npu else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if is_mps_or_npu else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py

Lines changed: 3 additions & 3 deletions
@@ -1030,11 +1030,11 @@ def __call__(
                 if not torch.is_tensor(current_timestep):
                     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                     # This would be a good case for the `match` statement (Python 3.10+)
-                    is_mps = latent_model_input.device.type == "mps"
+                    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
                     if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
+                        dtype = torch.float32 if is_mps_or_npu else torch.float64
                     else:
-                        dtype = torch.int32 if is_mps else torch.int64
+                        dtype = torch.int32 if is_mps_or_npu else torch.int64
                     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
                 elif len(current_timestep.shape) == 0:
                     current_timestep = current_timestep[None].to(latent_model_input.device)

examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py

Lines changed: 3 additions & 3 deletions
@@ -257,11 +257,11 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
+            is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if is_mps_or_npu else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if is_mps_or_npu else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet.py

Lines changed: 3 additions & 3 deletions
@@ -739,11 +739,11 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
+            is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if is_mps_or_npu else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if is_mps_or_npu else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_sparsectrl.py

Lines changed: 3 additions & 3 deletions
@@ -670,11 +670,11 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
+            is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if is_mps_or_npu else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if is_mps_or_npu else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_xs.py

Lines changed: 3 additions & 3 deletions
@@ -1087,11 +1087,11 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
+            is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if is_mps_or_npu else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if is_mps_or_npu else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 2 deletions
@@ -955,8 +955,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         cos_out = []
         sin_out = []
         pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        freqs_dtype = torch.float32 if is_mps else torch.float64
+        is_mps_or_npu = ids.device.type == "mps" or ids.device.type == "npu"
+        freqs_dtype = torch.float32 if is_mps_or_npu else torch.float64
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
                 self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
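In embeddings.py the same device check decides the precision used to build the rotary position-embedding frequencies: float64 where the backend supports it, float32 on MPS/NPU. A hedged, self-contained illustration of why that dtype matters (standard RoPE math, not copied from diffusers; theta, dim, and the shapes are example values):

import torch

dim, theta, n_pos = 64, 10000.0, 8
freqs_dtype = torch.float64  # torch.float32 on "mps"/"npu" after this fix
# The inverse frequencies 1 / theta**(2i/dim) span many orders of magnitude,
# which is why higher precision is preferred where the device allows it.
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype) / dim))
pos = torch.arange(n_pos, dtype=freqs_dtype)
freqs = torch.outer(pos, inv_freq)                    # (n_pos, dim // 2)
cos, sin = freqs.cos().float(), freqs.sin().float()   # cast back to float32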

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 3 additions & 3 deletions
@@ -914,11 +914,11 @@ def get_time_embed(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
+            is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if is_mps_or_npu else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if is_mps_or_npu else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/unets/unet_3d_condition.py

Lines changed: 3 additions & 3 deletions
@@ -623,11 +623,11 @@ def forward(
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
-            is_mps = sample.device.type == "mps"
+            is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if is_mps_or_npu else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if is_mps_or_npu else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
