 from torch import nn

 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version


 if is_torch_npu_available():
@@ -79,7 +79,7 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
         self.approximate = approximate

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type == "mps" and torch.__version__ < '2.0.0':
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
             # fp16 gelu not supported on mps before torch 2.0
             return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
         return F.gelu(gate, approximate=self.approximate)
@@ -105,7 +105,7 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type == "mps" and torch.__version__ < '2.0.0':
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
             # fp16 gelu not supported on mps before torch 2.0
             return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
         return F.gelu(gate)
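Context for the change: comparing raw version strings is lexicographic, so it breaks on multi-digit components and on local suffixes such as "2.0.0+cu118", which is why the diff swaps the string comparison for is_torch_version. A minimal sketch of the failure mode, assuming only the packaging dependency that version-parsing helpers like this one typically rely on:

from packaging import version

# Plain string comparison is character-by-character, so "2.10.0" sorts
# *before* "2.2.0" ('1' < '2' at the third character) even though it is
# the newer release:
assert "2.10.0" < "2.2.0"

# Parsing into Version objects compares release components numerically,
# which is in effect what the is_torch_version helper does:
assert version.parse("2.10.0") > version.parse("2.2.0")
# PEP 440 also orders local-version builds sensibly relative to the base:
assert version.parse("2.0.0+cu118") >= version.parse("2.0.0")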