We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 193570f commit 1dae1bcCopy full SHA for 1dae1bc
source/train/common.py
@@ -2,6 +2,7 @@
2
import numpy as np
3
import math
4
from deepmd.env import tf
5
+from deepmd.RunOptions import global_tf_float_precision
6
7
def gelu(x):
8
"""Gaussian Error Linear Unit.
@@ -172,3 +173,16 @@ def expand_sys_str(root_dir):
172
173
for filename in fnmatch.filter(filenames, 'type.raw'):
174
matches.append(root)
175
return matches
176
+
177
+def get_precision_func(precision):
178
+ if precision == 0:
179
+ return global_tf_float_precision
180
+ elif precision == 16:
181
+ return tf.float16
182
+ elif precision == 32:
183
+ return tf.float32
184
+ elif precision == 64:
185
+ return tf.float64
186
+ else:
187
+ raise RuntimeError("%d is not a valid precision" % precision)
188
0 commit comments