Skip to content

Commit 8d6fc08

Browse files
authored
Update Fitting.py
1 parent dec7a7e commit 8d6fc08

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

source/train/Fitting.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def build (self,
233233
else:
234234
outs = tf.concat([outs, final_layer], axis = 1)
235235

236-
return tf.cast(tf.reshape(outs, [-1]), self.fitting_precision)
236+
return tf.cast(tf.reshape(outs, [-1]), global_tf_float_precision)
237237

238238

239239
class WFCFitting () :
@@ -316,7 +316,7 @@ def build (self,
316316
outs = tf.concat([outs, final_layer], axis = 1)
317317
count += 1
318318

319-
return tf.cast(tf.reshape(outs, [-1]), self.fitting_precision)
319+
return tf.cast(tf.reshape(outs, [-1]), global_tf_float_precision)
320320

321321

322322

@@ -398,7 +398,7 @@ def build (self,
398398
outs = tf.concat([outs, final_layer], axis = 1)
399399
count += 1
400400

401-
return tf.cast(tf.reshape(outs, [-1]), self.fitting_precision)
401+
return tf.cast(tf.reshape(outs, [-1]), global_tf_float_precision)
402402

403403

404404
class PolarFittingSeA () :
@@ -530,7 +530,7 @@ def build (self,
530530
outs = tf.concat([outs, final_layer], axis = 1)
531531
count += 1
532532

533-
return tf.cast(tf.reshape(outs, [-1]), self.fitting_precision)
533+
return tf.cast(tf.reshape(outs, [-1]), global_tf_float_precision)
534534

535535

536536
class GlobalPolarFittingSeA () :
@@ -637,5 +637,5 @@ def build (self,
637637
outs = tf.concat([outs, final_layer], axis = 1)
638638
count += 1
639639

640-
return tf.cast(tf.reshape(outs, [-1]), self.fitting_precision)
640+
return tf.cast(tf.reshape(outs, [-1]), global_tf_float_precision)
641641
# return tf.reshape(outs, [tf.shape(inputs)[0] * natoms[0] * 3 // 3])

0 commit comments

Comments
 (0)