Skip to content

Commit cab5d39

Browse files
committed
Merge remote-tracking branch 'origin/vace_22' into vace_22
2 parents deae16a + 41d962b commit cab5d39

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,13 @@ def module_is_offloaded(module):
505505
os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
506506
logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
507507

508+
if dtype in (torch.bfloat16, None) and kwargs.pop("sdp_on_bf16", True):
509+
if hasattr(torch._C, "_set_math_sdp_allow_fp16_bf16_reduction"):
510+
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
511+
logger.warning(
512+
"Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`"
513+
)
514+
508515
module_names, _ = self._get_signature_keys(self)
509516
modules = [getattr(self, n, None) for n in module_names]
510517
modules = [m for m in modules if isinstance(m, torch.nn.Module)]

src/diffusers/pipelines/wan/pipeline_wan_vace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ def __call__(
814814
video,
815815
mask,
816816
reference_images,
817-
guidance_scale_2
817+
guidance_scale_2,
818818
)
819819

820820
if num_frames % self.vae_scale_factor_temporal != 1:

0 commit comments

Comments
 (0)