File tree Expand file tree Collapse file tree 1 file changed +7
-0
lines changed
Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Original file line number Diff line number Diff line change @@ -779,6 +779,13 @@ def module_is_offloaded(module):
779779 os .environ ["PT_HPU_MAX_COMPOUND_OP_SIZE" ] = "1"
780780 logger .debug ("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1" )
781781
782+ if dtype in (torch .bfloat16 , None ) and kwargs .pop ("sdp_on_bf16" , True ):
783+ if hasattr (torch ._C , "_set_math_sdp_allow_fp16_bf16_reduction" ):
784+ torch ._C ._set_math_sdp_allow_fp16_bf16_reduction (True )
785+ logger .warning (
786+ "Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`"
787+ )
788+
782789 module_names , _ = self ._get_signature_keys (self )
783790 modules = [getattr (self , n , None ) for n in module_names ]
784791 modules = [m for m in modules if isinstance (m , torch .nn .Module )]
You can’t perform that action at this time.
0 commit comments