from torch import nn

from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version


if is_torch_npu_available():
@@ -79,10 +79,10 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)

    def forward(self, hidden_states, *args, **kwargs):
        if len(args) > 0 or kwargs.get("scale", None) is not None:
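As a rough, self-contained sketch of the pattern these hunks apply (the helper name and the use of packaging.version are illustrative stand-ins for the library's own is_torch_version check, not part of the diff): on MPS with torch older than 2.0, where fp16 gelu is unimplemented, compute in float32 and cast back; everywhere else call gelu directly.

import torch
import torch.nn.functional as F
from packaging import version


def gelu_mps_safe(gate: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    # Hypothetical standalone helper: upcast only on MPS with torch < 2.0,
    # then restore the caller's dtype so downstream layers see no change.
    if gate.device.type == "mps" and version.parse(torch.__version__) < version.parse("2.0.0"):
        return F.gelu(gate.to(dtype=torch.float32), approximate=approximate).to(dtype=gate.dtype)
    return F.gelu(gate, approximate=approximate)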