@@ -233,7 +233,7 @@ def build (self,
233233 else :
234234 outs = tf .concat ([outs , final_layer ], axis = 1 )
235235
236- return tf .cast (tf .reshape (outs , [- 1 ]), self . fitting_precision )
236+ return tf .cast (tf .reshape (outs , [- 1 ]), global_tf_float_precision )
237237
238238
239239class WFCFitting () :
@@ -316,7 +316,7 @@ def build (self,
316316 outs = tf .concat ([outs , final_layer ], axis = 1 )
317317 count += 1
318318
319- return tf .cast (tf .reshape (outs , [- 1 ]), self . fitting_precision )
319+ return tf .cast (tf .reshape (outs , [- 1 ]), global_tf_float_precision )
320320
321321
322322
@@ -398,7 +398,7 @@ def build (self,
398398 outs = tf .concat ([outs , final_layer ], axis = 1 )
399399 count += 1
400400
401- return tf .cast (tf .reshape (outs , [- 1 ]), self . fitting_precision )
401+ return tf .cast (tf .reshape (outs , [- 1 ]), global_tf_float_precision )
402402
403403
404404class PolarFittingSeA () :
@@ -530,7 +530,7 @@ def build (self,
530530 outs = tf .concat ([outs , final_layer ], axis = 1 )
531531 count += 1
532532
533- return tf .cast (tf .reshape (outs , [- 1 ]), self . fitting_precision )
533+ return tf .cast (tf .reshape (outs , [- 1 ]), global_tf_float_precision )
534534
535535
536536class GlobalPolarFittingSeA () :
@@ -637,5 +637,5 @@ def build (self,
637637 outs = tf .concat ([outs , final_layer ], axis = 1 )
638638 count += 1
639639
640- return tf .cast (tf .reshape (outs , [- 1 ]), self . fitting_precision )
640+ return tf .cast (tf .reshape (outs , [- 1 ]), global_tf_float_precision )
641641 # return tf.reshape(outs, [tf.shape(inputs)[0] * natoms[0] * 3 // 3])
0 commit comments