From 8e73b583724abb220f3104c6827e35ad9fd7725a Mon Sep 17 00:00:00 2001 From: Daniel Socek Date: Wed, 10 Sep 2025 14:19:32 +0000 Subject: [PATCH 1/2] Use SDP on BF16 in GPU/HPU migration Signed-off-by: Daniel Socek --- src/diffusers/pipelines/pipeline_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0116ad917c00..a39188e4667f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -504,6 +504,11 @@ def module_is_offloaded(module): os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1" logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1") + if dtype in (torch.bfloat16, None) and kwargs.pop("sdp_on_bf16", True): + if hasattr(torch._C, "_set_math_sdp_allow_fp16_bf16_reduction"): + torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) + logger.warning(f"Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`") + module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names] modules = [m for m in modules if isinstance(m, torch.nn.Module)] From b05259a9c6a7fc87bb5de80481fb7189883ac652 Mon Sep 17 00:00:00 2001 From: Daniel Socek Date: Fri, 12 Sep 2025 12:32:25 +0000 Subject: [PATCH 2/2] Formatting fix for enabling SDP with BF16 precision on HPU Signed-off-by: Daniel Socek --- src/diffusers/pipelines/pipeline_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a39188e4667f..151c2ccb4270 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -507,7 +507,9 @@ def module_is_offloaded(module): if dtype in (torch.bfloat16, None) and kwargs.pop("sdp_on_bf16", True): if hasattr(torch._C, "_set_math_sdp_allow_fp16_bf16_reduction"): torch._C._set_math_sdp_allow_fp16_bf16_reduction(True) - logger.warning(f"Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`") + logger.warning( + "Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`" + ) module_names, _ = self._get_signature_keys(self) modules = [getattr(self, n, None) for n in module_names]