Skip to content

Commit 8e28cf8

Browse files
authored
Update Fitting.py
1 parent c0e0a14 commit 8e28cf8

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

source/train/Fitting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,14 @@ def build (self,
194194
if self.numb_fparam > 0 :
195195
ext_fparam = tf.tile(fparam, [1, natoms[2+type_i]])
196196
ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam])
197+
ext_fparam = tf.cast(ext_fparam,self.fitting_precision)
197198
layer = tf.concat([layer, ext_fparam], axis = 1)
198199
if self.numb_aparam > 0 :
199200
ext_aparam = tf.slice(aparam,
200201
[ 0, start_index * self.numb_aparam],
201202
[-1, natoms[2+type_i] * self.numb_aparam])
202203
ext_aparam = tf.reshape(ext_aparam, [-1, self.numb_aparam])
204+
ext_aparam = tf.cast(ext_aparam,self.fitting_precision)
203205
layer = tf.concat([layer, ext_aparam], axis = 1)
204206
start_index += natoms[2+type_i]
205207

0 commit comments

Comments
 (0)