Skip to content

Commit 6a1b1fb

Browse files
authored
Update common.py
1 parent 6efe946 commit 6a1b1fb

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

source/train/common.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,28 @@
11
import os,warnings,fnmatch
22
import numpy as np
3+
import math
34
from deepmd.env import tf
45

6+
def gelu(x):
7+
"""Gaussian Error Linear Unit.
8+
This is a smoother version of the RELU.
9+
Original paper: https://arxiv.org/abs/1606.08415
10+
Args:
11+
x: float Tensor to perform activation.
12+
Returns:
13+
`x` with the GELU activation applied.
14+
"""
15+
cdf = 0.5 * (1.0 + tf.tanh((math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3)))))
16+
return x * cdf
17+
518
data_requirement = {}
619
activation_fn_dict = {
720
"relu": tf.nn.relu,
821
"relu6": tf.nn.relu6,
922
"softplus": tf.nn.softplus,
1023
"sigmoid": tf.sigmoid,
11-
"tanh": tf.nn.tanh
24+
"tanh": tf.nn.tanh,
25+
"gelu": gelu
1226
}
1327
def add_data_requirement(key,
1428
ndof,

0 commit comments

Comments
 (0)