Skip to content

Commit 8fe6408

Browse files
committed
is_mps is_npu
1 parent 3454384 commit 8fe6408

21 files changed

+83
-62
lines changed

examples/community/fresco_v2v.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,12 @@ def forward(
403403
if not torch.is_tensor(timesteps):
404404
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
405405
# This would be a good case for the `match` statement (Python 3.10+)
406-
is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
406+
is_mps = sample.device.type == "mps"
407+
is_npu = sample.device.type == "npu"
407408
if isinstance(timestep, float):
408-
dtype = torch.float32 if is_mps_or_npu else torch.float64
409+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
409410
else:
410-
dtype = torch.int32 if is_mps_or_npu else torch.int64
411+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
411412
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
412413
elif len(timesteps.shape) == 0:
413414
timesteps = timesteps[None].to(sample.device)

examples/community/matryoshka.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2805,11 +2805,12 @@ def get_time_embed(
28052805
if not torch.is_tensor(timesteps):
28062806
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
28072807
# This would be a good case for the `match` statement (Python 3.10+)
2808-
is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
2808+
is_mps = sample.device.type == "mps"
2809+
is_npu = sample.device.type == "npu"
28092810
if isinstance(timestep, float):
2810-
dtype = torch.float32 if is_mps_or_npu else torch.float64
2811+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
28112812
else:
2812-
dtype = torch.int32 if is_mps_or_npu else torch.int64
2813+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
28132814
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
28142815
elif len(timesteps.shape) == 0:
28152816
timesteps = timesteps[None].to(sample.device)

examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,11 +1030,12 @@ def __call__(
10301030
if not torch.is_tensor(current_timestep):
10311031
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
10321032
# This would be a good case for the `match` statement (Python 3.10+)
1033-
is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
1033+
is_mps = latent_model_input.device.type == "mps"
1034+
is_npu = latent_model_input.device.type == "npu"
10341035
if isinstance(current_timestep, float):
1035-
dtype = torch.float32 if is_mps_or_npu else torch.float64
1036+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
10361037
else:
1037-
dtype = torch.int32 if is_mps_or_npu else torch.int64
1038+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
10381039
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
10391040
elif len(current_timestep.shape) == 0:
10401041
current_timestep = current_timestep[None].to(latent_model_input.device)

examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,11 +257,12 @@ def forward(
257257
if not torch.is_tensor(timesteps):
258258
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
259259
# This would be a good case for the `match` statement (Python 3.10+)
260-
is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
260+
is_mps = sample.device.type == "mps"
261+
is_npu = sample.device.type == "npu"
261262
if isinstance(timestep, float):
262-
dtype = torch.float32 if is_mps_or_npu else torch.float64
263+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
263264
else:
264-
dtype = torch.int32 if is_mps_or_npu else torch.int64
265+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
265266
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
266267
elif len(timesteps.shape) == 0:
267268
timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -739,11 +739,12 @@ def forward(
739739
if not torch.is_tensor(timesteps):
740740
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
741741
# This would be a good case for the `match` statement (Python 3.10+)
742-
is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
742+
is_mps = sample.device.type == "mps"
743+
is_npu = sample.device.type == "npu"
743744
if isinstance(timestep, float):
744-
dtype = torch.float32 if is_mps_or_npu else torch.float64
745+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
745746
else:
746-
dtype = torch.int32 if is_mps_or_npu else torch.int64
747+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
747748
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
748749
elif len(timesteps.shape) == 0:
749750
timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_sparsectrl.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -670,11 +670,12 @@ def forward(
670670
if not torch.is_tensor(timesteps):
671671
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
672672
# This would be a good case for the `match` statement (Python 3.10+)
673-
is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
673+
is_mps = sample.device.type == "mps"
674+
is_npu = sample.device.type == "npu"
674675
if isinstance(timestep, float):
675-
dtype = torch.float32 if is_mps_or_npu else torch.float64
676+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
676677
else:
677-
dtype = torch.int32 if is_mps_or_npu else torch.int64
678+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
678679
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
679680
elif len(timesteps.shape) == 0:
680681
timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,10 +681,11 @@ def forward(
681681
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
682682
# This would be a good case for the `match` statement (Python 3.10+)
683683
is_mps = sample.device.type == "mps"
684+
is_npu = sample.device.type == "npu"
684685
if isinstance(timestep, float):
685-
dtype = torch.float32 if is_mps else torch.float64
686+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
686687
else:
687-
dtype = torch.int32 if is_mps else torch.int64
688+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
688689
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
689690
elif len(timesteps.shape) == 0:
690691
timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_xs.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,11 +1087,12 @@ def forward(
10871087
if not torch.is_tensor(timesteps):
10881088
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
10891089
# This would be a good case for the `match` statement (Python 3.10+)
1090-
is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
1090+
is_mps = sample.device.type == "mps"
1091+
is_npu = sample.device.type == "npu"
10911092
if isinstance(timestep, float):
1092-
dtype = torch.float32 if is_mps_or_npu else torch.float64
1093+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
10931094
else:
1094-
dtype = torch.int32 if is_mps_or_npu else torch.int64
1095+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
10951096
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
10961097
elif len(timesteps.shape) == 0:
10971098
timesteps = timesteps[None].to(sample.device)

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -914,11 +914,12 @@ def get_time_embed(
914914
if not torch.is_tensor(timesteps):
915915
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
916916
# This would be a good case for the `match` statement (Python 3.10+)
917-
is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
917+
is_mps = sample.device.type == "mps"
918+
is_npu = sample.device.type == "npu"
918919
if isinstance(timestep, float):
919-
dtype = torch.float32 if is_mps_or_npu else torch.float64
920+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
920921
else:
921-
dtype = torch.int32 if is_mps_or_npu else torch.int64
922+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
922923
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
923924
elif len(timesteps.shape) == 0:
924925
timesteps = timesteps[None].to(sample.device)

src/diffusers/models/unets/unet_3d_condition.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -623,11 +623,12 @@ def forward(
623623
if not torch.is_tensor(timesteps):
624624
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
625625
# This would be a good case for the `match` statement (Python 3.10+)
626-
is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
626+
is_mps = sample.device.type == "mps"
627+
is_npu = sample.device.type == "npu"
627628
if isinstance(timestep, float):
628-
dtype = torch.float32 if is_mps_or_npu else torch.float64
629+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
629630
else:
630-
dtype = torch.int32 if is_mps_or_npu else torch.int64
631+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
631632
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
632633
elif len(timesteps.shape) == 0:
633634
timesteps = timesteps[None].to(sample.device)

0 commit comments

Comments
 (0)