 from torch import nn

 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version


 if is_torch_npu_available():
@@ -79,10 +79,10 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
         self.approximate = approximate

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)

     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)

     def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
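A minimal usage sketch of the patched classes, assuming they live in diffusers.models.activations (the module being changed here) and that diffusers and torch are installed; device/version gating only changes behavior on "mps" with torch < 2.0:

import torch
from diffusers.models.activations import GELU, GEGLU

x = torch.randn(2, 16)
gelu_layer = GELU(dim_in=16, dim_out=32)
geglu_layer = GEGLU(dim_in=16, dim_out=32)

# With this patch, the float32 upcast around F.gelu only happens on "mps"
# devices running torch < 2.0; on every other device the activation runs
# directly in the input dtype.
print(gelu_layer(x).shape)   # torch.Size([2, 32])
print(geglu_layer(x).shape)  # torch.Size([2, 32])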