Skip to content

Commit 1dae1bc

Browse files
authored
Update common.py
1 parent 193570f commit 1dae1bc

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

source/train/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import math
44
from deepmd.env import tf
5+
from deepmd.RunOptions import global_tf_float_precision
56

67
def gelu(x):
78
"""Gaussian Error Linear Unit.
@@ -172,3 +173,16 @@ def expand_sys_str(root_dir):
172173
for filename in fnmatch.filter(filenames, 'type.raw'):
173174
matches.append(root)
174175
return matches
176+
177+
def get_precision_func(precision):
178+
if precision == 0:
179+
return global_tf_float_precision
180+
elif precision == 16:
181+
return tf.float16
182+
elif precision == 32:
183+
return tf.float32
184+
elif precision == 64:
185+
return tf.float64
186+
else:
187+
raise RuntimeError("%d is not a valid precision" % precision)
188+

0 commit comments

Comments
 (0)