Skip to content

Commit afca35f

Browse files
authored
Update common.py
1 parent 886ca09 commit afca35f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

source/train/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,13 @@ def expand_sys_str(root_dir):
175175
return matches
176176

177177
def get_precision_func(precision):
178-
if precision == 0:
178+
if precision == "default":
179179
return global_tf_float_precision
180-
elif precision == 16:
180+
elif precision == "float16":
181181
return tf.float16
182-
elif precision == 32:
182+
elif precision == "float32":
183183
return tf.float32
184-
elif precision == 64:
184+
elif precision == "float64":
185185
return tf.float64
186186
else:
187187
raise RuntimeError("%d is not a valid precision" % precision)

0 commit comments

Comments
 (0)