File tree Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Original file line number Diff line number Diff line change @@ -505,6 +505,13 @@ def module_is_offloaded(module):
505505 os .environ ["PT_HPU_MAX_COMPOUND_OP_SIZE" ] = "1"
506506 logger .debug ("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1" )
507507
508+ if dtype in (torch .bfloat16 , None ) and kwargs .pop ("sdp_on_bf16" , True ):
509+ if hasattr (torch ._C , "_set_math_sdp_allow_fp16_bf16_reduction" ):
510+ torch ._C ._set_math_sdp_allow_fp16_bf16_reduction (True )
511+ logger .warning (
512+ "Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`"
513+ )
514+
508515 module_names , _ = self ._get_signature_keys (self )
509516 modules = [getattr (self , n , None ) for n in module_names ]
510517 modules = [m for m in modules if isinstance (m , torch .nn .Module )]
You can’t perform that action at this time.
0 commit comments