Skip to content

Commit 8a238c3

Browse files
Merge branch 'huggingface:main' into main
2 parents bb41c2b + f5c113e commit 8a238c3

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,13 @@ def module_is_offloaded(module):
779779
os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
780780
logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
781781

782+
if dtype in (torch.bfloat16, None) and kwargs.pop("sdp_on_bf16", True):
783+
if hasattr(torch._C, "_set_math_sdp_allow_fp16_bf16_reduction"):
784+
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
785+
logger.warning(
786+
"Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`"
787+
)
788+
782789
module_names, _ = self._get_signature_keys(self)
783790
modules = [getattr(self, n, None) for n in module_names]
784791
modules = [m for m in modules if isinstance(m, torch.nn.Module)]

0 commit comments

Comments
 (0)