
Commit d15c01d

Merge branch 'main' into 2025-license

2 parents 63194f9 + 75a636d

26 files changed: +68, -47 lines

docs/source/en/api/pipelines/mochi.md

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ export_to_video(frames, "mochi.mp4", fps=30)
 
 ## Reproducing the results from the Genmo Mochi repo
 
-The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the the original implementation, please refer to the following example.
+The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the original implementation, please refer to the following example.
 
 <Tip>
 The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.

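Note on the mochi.md hunk above: the paragraph it touches describes a mixed-precision recipe (float32 text encoder and VAE, bfloat16 DiT, efficient SDPA backend). A minimal sketch of that recipe follows; it assumes the `genmo/mochi-1-preview` checkpoint id and standard `MochiPipeline` arguments rather than quoting the actual example from `mochi.md`.

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from diffusers import MochiPipeline
from diffusers.utils import export_to_video

# Text encoder and VAE stay in float32 via the pipeline-wide dtype...
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float32)
pipe.enable_model_cpu_offload()

# ...while only the DiT is cast down to bfloat16. (The checkpoint id and
# step count here are assumptions, not taken from this commit.)
pipe.transformer.to(torch.bfloat16)

prompt = "An aerial shot of a rocky coastline at sunset."
with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
    # Force the EFFICIENT_ATTENTION SDPA backend, as the original repo does.
    with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
        frames = pipe(prompt, num_inference_steps=64).frames[0]

export_to_video(frames, "mochi.mp4", fps=30)

As the <Tip> in the diff warns, combining autocast with the zero-out-empty-prompts option can overflow the T5 text encoder, so that option is left at its default here.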
examples/community/fresco_v2v.py

Lines changed: 3 additions & 2 deletions
@@ -404,10 +404,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

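This identical `is_npu` hunk recurs in most of the files below; it widens the existing MPS special case so that `"npu"` devices also avoid 64-bit timestep dtypes. Pulled out of diff context, the pattern looks roughly like the following self-contained sketch; the wrapping helper and the `torch.is_tensor` guard are assumptions about the surrounding diffusers code, not part of this commit, and the rationale (weak 64-bit dtype support on those backends) is inferred from the change itself.

import torch


def to_timestep_tensor(timestep, sample):
    # Coerce a Python number or 0-d tensor into a 1-d tensor on sample's device.
    if not torch.is_tensor(timestep):
        # MPS and NPU backends handle 64-bit dtypes poorly, so fall back to 32-bit.
        is_mps = sample.device.type == "mps"
        is_npu = sample.device.type == "npu"
        if isinstance(timestep, float):
            dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        else:
            dtype = torch.int32 if (is_mps or is_npu) else torch.int64
        return torch.tensor([timestep], dtype=dtype, device=sample.device)
    elif len(timestep.shape) == 0:
        return timestep[None].to(sample.device)
    return timestep


sample = torch.zeros(1, 4, 8, 8)          # stand-in for a latent batch on CPU
print(to_timestep_tensor(999, sample))    # tensor([999]) with dtype torch.int64
print(to_timestep_tensor(0.5, sample))    # tensor([0.5000], dtype=torch.float64)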
examples/community/matryoshka.py

Lines changed: 3 additions & 2 deletions
@@ -2806,10 +2806,11 @@ def get_time_embed(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py

Lines changed: 3 additions & 2 deletions
@@ -1031,10 +1031,11 @@ def __call__(
                 # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                 # This would be a good case for the `match` statement (Python 3.10+)
                 is_mps = latent_model_input.device.type == "mps"
+                is_npu = latent_model_input.device.type == "npu"
                 if isinstance(current_timestep, float):
-                    dtype = torch.float32 if is_mps else torch.float64
+                    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                 else:
-                    dtype = torch.int32 if is_mps else torch.int64
+                    dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
             elif len(current_timestep.shape) == 0:
                 current_timestep = current_timestep[None].to(latent_model_input.device)

examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py

Lines changed: 3 additions & 2 deletions
@@ -258,10 +258,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

scripts/convert_consistency_decoder.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def _download(url: str, root: str):
             loop.update(len(buffer))
 
     if insecure_hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
-        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
+        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
 
     return download_target

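The line fixed above sits in a standard download-then-verify routine. A generic sketch of that check, substituting the standard-library `hashlib` for the repo's `insecure_hashlib` wrapper:

import hashlib


def verify_sha256(download_target, expected_sha256):
    # Hash the downloaded file and compare against the published digest.
    with open(download_target, "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()
    if digest != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
    return download_target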
src/diffusers/models/controlnets/controlnet.py

Lines changed: 3 additions & 2 deletions
@@ -740,10 +740,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_sparsectrl.py

Lines changed: 3 additions & 2 deletions
@@ -671,10 +671,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 3 additions & 2 deletions
@@ -681,10 +681,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_xs.py

Lines changed: 3 additions & 2 deletions
@@ -1088,10 +1088,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
