File tree Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Expand file tree Collapse file tree 1 file changed +7
-0
lines changed Original file line number Diff line number Diff line change @@ -505,6 +505,13 @@ def module_is_offloaded(module):
505505            os .environ ["PT_HPU_MAX_COMPOUND_OP_SIZE" ] =  "1" 
506506            logger .debug ("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1" )
507507
508+             if  dtype  in  (torch .bfloat16 , None ) and  kwargs .pop ("sdp_on_bf16" , True ):
509+                 if  hasattr (torch ._C , "_set_math_sdp_allow_fp16_bf16_reduction" ):
510+                     torch ._C ._set_math_sdp_allow_fp16_bf16_reduction (True )
511+                     logger .warning (
512+                         "Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`" 
513+                     )
514+ 
508515        module_names , _  =  self ._get_signature_keys (self )
509516        modules  =  [getattr (self , n , None ) for  n  in  module_names ]
510517        modules  =  [m  for  m  in  modules  if  isinstance (m , torch .nn .Module )]
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments