This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit fecd3e1

Bartlomiej Gawrych (bgawrych) authored
Use tanh approximation in gelu (#1590)
Co-authored-by: Bartlomiej Gawrych <[email protected]>
1 parent ce0e0a2 commit fecd3e1

File tree

1 file changed: +6 additions, -2 deletions


src/gluonnlp/layers.py

Lines changed: 6 additions & 2 deletions
@@ -32,6 +32,7 @@
 
 InitializerType = Optional[Union[mx.init.Initializer, str]]
 
+GELU_TANH_SUPPORT = 'gelu_tanh' in mx.symbol.LeakyReLU.__doc__
 
 @use_np
 def get_norm_layer(normalization: str = 'layer_norm',
@@ -322,8 +323,11 @@ def forward(self, x):
         if self._mode == 'erf':
             return npx.leaky_relu(x, act_type='gelu')
         elif self._mode == 'tanh':
-            return 0.5 * x\
-                * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * (x ** 3))))
+            if GELU_TANH_SUPPORT:
+                return npx.leaky_relu(x, act_type='gelu_tanh')
+            else:
+                return 0.5 * x\
+                    * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * (x ** 3))))
         elif self._mode == 'sigmoid':
             return x * npx.sigmoid(1.702 * x)
         else:
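For context, the two GELU formulations involved in this change can be sketched in plain NumPy, with no MXNet dependency (the function names here are illustrative, not part of the patched module). The exact GELU is x * Phi(x) with Phi the standard normal CDF, while the tanh form is the approximation kept as the Python fallback in the diff; the commit simply prefers MXNet's fused `act_type='gelu_tanh'` kernel when the installed build advertises it in the `LeakyReLU` docstring.

```python
import math
import numpy as np

def gelu_erf(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF,
    # expressed via the error function: Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # Tanh approximation, identical to the fallback expression in the diff.
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi)
                                    * (x + 0.044715 * x ** 3)))

x = np.linspace(-4.0, 4.0, 9)
# The approximation tracks the exact form closely over this range.
max_err = np.max(np.abs(gelu_erf(x) - gelu_tanh(x)))
assert max_err < 1e-2
```

The point of the feature check (`'gelu_tanh' in mx.symbol.LeakyReLU.__doc__`) is that newer MXNet builds can evaluate this tanh variant in one fused operator instead of the chain of elementwise ops above, while older builds still get the same numerical result from the Python expression.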

0 commit comments
