6 changes: 3 additions & 3 deletions examples/community/fresco_v2v.py
@@ -403,11 +403,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
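Every hunk in this PR follows the same pattern: Ascend NPU devices are now treated like MPS when choosing the dtype for a scalar timestep, presumably because 64-bit tensor dtypes are poorly supported on those backends. A minimal standalone sketch of the selection logic (the helper name and the membership test are illustrative only; the diffusers code inlines this logic rather than factoring it out):

import torch

def timestep_dtype(device: torch.device, timestep) -> torch.dtype:
    # Hypothetical helper, not part of this PR: mps/npu fall back to 32-bit dtypes,
    # every other device keeps the 64-bit default.
    needs_32bit = device.type in ("mps", "npu")
    if isinstance(timestep, float):
        return torch.float32 if needs_32bit else torch.float64
    return torch.int32 if needs_32bit else torch.int64

For example, timestep_dtype(torch.device("mps"), 999) returns torch.int32, while the same integer timestep on a CPU or CUDA device keeps torch.int64.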
6 changes: 3 additions & 3 deletions examples/community/matryoshka.py
@@ -2805,11 +2805,11 @@ def get_time_embed(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
@@ -1030,11 +1030,11 @@ def __call__(
 if not torch.is_tensor(current_timestep):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = latent_model_input.device.type == "mps"
+    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
     if isinstance(current_timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
 elif len(current_timestep.shape) == 0:
     current_timestep = current_timestep[None].to(latent_model_input.device)
@@ -257,11 +257,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
6 changes: 3 additions & 3 deletions src/diffusers/models/controlnets/controlnet.py
@@ -739,11 +739,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
6 changes: 3 additions & 3 deletions src/diffusers/models/controlnets/controlnet_sparsectrl.py
@@ -670,11 +670,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
6 changes: 3 additions & 3 deletions src/diffusers/models/controlnets/controlnet_xs.py
@@ -1087,11 +1087,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
4 changes: 2 additions & 2 deletions src/diffusers/models/embeddings.py
@@ -955,8 +955,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
 cos_out = []
 sin_out = []
 pos = ids.float()
-is_mps = ids.device.type == "mps"
-freqs_dtype = torch.float32 if is_mps else torch.float64
+is_mps_or_npu = ids.device.type == "mps" or ids.device.type == "npu"
+freqs_dtype = torch.float32 if is_mps_or_npu else torch.float64
 for i in range(n_axes):
     cos, sin = get_1d_rotary_pos_embed(
         self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
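The embeddings.py hunk above is the one place where the dtype switch governs rotary positional-embedding frequencies rather than timesteps: frequencies are computed in float64 where the backend allows it and in float32 on mps/npu, then cast back down. An illustrative sketch of that idea (a simplified stand-in, not the actual get_1d_rotary_pos_embed implementation):

import torch

def rope_freqs(dim: int, positions: torch.Tensor, theta: float = 10000.0):
    # Pick the widest float the backend handles well for the frequency math,
    # then return cos/sin tables cast back to float32.
    needs_32bit = positions.device.type in ("mps", "npu")
    freqs_dtype = torch.float32 if needs_32bit else torch.float64
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=positions.device) / dim))
    angles = positions.to(freqs_dtype)[:, None] * inv_freq[None, :]  # outer product: (num_pos, dim // 2)
    return angles.cos().float(), angles.sin().float()

The final cast to float32 keeps the extra precision confined to the frequency computation itself, so downstream layers see the same dtype on every backend.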
6 changes: 3 additions & 3 deletions src/diffusers/models/unets/unet_2d_condition.py
@@ -914,11 +914,11 @@ def get_time_embed(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
6 changes: 3 additions & 3 deletions src/diffusers/models/unets/unet_3d_condition.py
@@ -623,11 +623,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
6 changes: 3 additions & 3 deletions src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -574,11 +574,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timesteps, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
6 changes: 3 additions & 3 deletions src/diffusers/models/unets/unet_motion_model.py
@@ -2113,11 +2113,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
6 changes: 3 additions & 3 deletions src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -387,11 +387,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -767,11 +767,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
@@ -1162,11 +1162,11 @@ def forward(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = sample.device.type == "mps"
+    is_mps_or_npu = sample.device.type == "mps" or sample.device.type == "npu"
     if isinstance(timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(sample.device)
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/dit/pipeline_dit.py
@@ -177,11 +177,11 @@ def __call__(
 if not torch.is_tensor(timesteps):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = latent_model_input.device.type == "mps"
+    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
     if isinstance(timesteps, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
 elif len(timesteps.shape) == 0:
     timesteps = timesteps[None].to(latent_model_input.device)
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/latte/pipeline_latte.py
@@ -787,11 +787,11 @@ def __call__(
 if not torch.is_tensor(current_timestep):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = latent_model_input.device.type == "mps"
+    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
     if isinstance(current_timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
 elif len(current_timestep.shape) == 0:
     current_timestep = current_timestep[None].to(latent_model_input.device)
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -803,11 +803,11 @@ def __call__(
 if not torch.is_tensor(current_timestep):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = latent_model_input.device.type == "mps"
+    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
     if isinstance(current_timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     current_timestep = torch.tensor(
         [current_timestep],
         dtype=dtype,
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -797,11 +797,11 @@ def __call__(
 if not torch.is_tensor(current_timestep):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = latent_model_input.device.type == "mps"
+    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
     if isinstance(current_timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
 elif len(current_timestep.shape) == 0:
     current_timestep = current_timestep[None].to(latent_model_input.device)
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -897,11 +897,11 @@ def __call__(
 if not torch.is_tensor(current_timestep):
     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
     # This would be a good case for the `match` statement (Python 3.10+)
-    is_mps = latent_model_input.device.type == "mps"
+    is_mps_or_npu = latent_model_input.device.type == "mps" or latent_model_input.device.type == "npu"
     if isinstance(current_timestep, float):
-        dtype = torch.float32 if is_mps else torch.float64
+        dtype = torch.float32 if is_mps_or_npu else torch.float64
     else:
-        dtype = torch.int32 if is_mps else torch.int64
+        dtype = torch.int32 if is_mps_or_npu else torch.int64
     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
 elif len(current_timestep.shape) == 0:
     current_timestep = current_timestep[None].to(latent_model_input.device)