Skip to content

Commit 38448be

Browse files
authored
Merge pull request #194 from GeiduanLiu/devel
user specified precision for embedding and fitting nets
2 parents 193570f + 0721f60 commit 38448be

File tree

5 files changed

+83
-54
lines changed

5 files changed

+83
-54
lines changed

source/train/DescrptSeA.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from deepmd.env import tf
3-
from deepmd.common import ClassArg, get_activation_func
3+
from deepmd.common import ClassArg, get_activation_func, get_precision
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
@@ -19,7 +19,8 @@ def __init__ (self, jdata):
1919
.add('seed', int) \
2020
.add('exclude_types', list, default = []) \
2121
.add('set_davg_zero', bool, default = False) \
22-
.add('activation_function', str, default = 'tanh')
22+
.add('activation_function', str, default = 'tanh') \
23+
.add('precision', str, default = "default")
2324
class_data = args.parse(jdata)
2425
self.sel_a = class_data['sel']
2526
self.rcut_r = class_data['rcut']
@@ -30,6 +31,7 @@ def __init__ (self, jdata):
3031
self.seed = class_data['seed']
3132
self.trainable = class_data['trainable']
3233
self.filter_activation_fn = get_activation_func(class_data['activation_function'])
34+
self.filter_precision = get_precision(class_data['precision'])
3335
exclude_types = class_data['exclude_types']
3436
self.exclude_types = set()
3537
for tt in exclude_types:
@@ -247,7 +249,7 @@ def _pass_filter(self,
247249
[ 0, start_index* self.ndescrpt],
248250
[-1, natoms[2+type_i]* self.ndescrpt] )
249251
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
250-
layer, qmat = self._filter(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
252+
layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
251253
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
252254
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3])
253255
output.append(layer)
@@ -343,18 +345,18 @@ def _filter(self,
343345
for ii in range(1, len(outputs_size)):
344346
w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
345347
[outputs_size[ii - 1], outputs_size[ii]],
346-
global_tf_float_precision,
348+
self.filter_precision,
347349
tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
348350
trainable = trainable)
349351
b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
350352
[1, outputs_size[ii]],
351-
global_tf_float_precision,
353+
self.filter_precision,
352354
tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
353355
trainable = trainable)
354356
if self.filter_resnet_dt :
355357
idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
356358
[1, outputs_size[ii]],
357-
global_tf_float_precision,
359+
self.filter_precision,
358360
tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
359361
trainable = trainable)
360362
if outputs_size[ii] == outputs_size[ii-1]:
@@ -430,18 +432,18 @@ def _filter_type_ext(self,
430432
for ii in range(1, len(outputs_size)):
431433
w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
432434
[outputs_size[ii - 1], outputs_size[ii]],
433-
global_tf_float_precision,
435+
self.filter_precision,
434436
tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
435437
trainable = trainable)
436438
b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
437439
[1, outputs_size[ii]],
438-
global_tf_float_precision,
440+
self.filter_precision,
439441
tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
440442
trainable = trainable)
441443
if self.filter_resnet_dt :
442444
idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
443445
[1, outputs_size[ii]],
444-
global_tf_float_precision,
446+
self.filter_precision,
445447
tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
446448
trainable = trainable)
447449
if outputs_size[ii] == outputs_size[ii-1]:

source/train/DescrptSeR.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from deepmd.env import tf
3-
from deepmd.common import ClassArg, get_activation_func
3+
from deepmd.common import ClassArg, get_activation_func, get_precision
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
@@ -18,7 +18,8 @@ def __init__ (self, jdata):
1818
.add('seed', int) \
1919
.add('exclude_types', list, default = []) \
2020
.add('set_davg_zero', bool, default = False) \
21-
.add("activation_function", str, default = "tanh")
21+
.add("activation_function", str, default = "tanh") \
22+
.add("precision", str, default = "default")
2223
class_data = args.parse(jdata)
2324
self.sel_r = class_data['sel']
2425
self.rcut = class_data['rcut']
@@ -27,7 +28,8 @@ def __init__ (self, jdata):
2728
self.filter_resnet_dt = class_data['resnet_dt']
2829
self.seed = class_data['seed']
2930
self.trainable = class_data['trainable']
30-
self.filter_activation_fn = get_activation_func(class_data["activation_function"])
31+
self.filter_activation_fn = get_activation_func(class_data["activation_function"])
32+
self.filter_precision = get_precision(class_data['precision'])
3133
exclude_types = class_data['exclude_types']
3234
self.exclude_types = set()
3335
for tt in exclude_types:
@@ -206,7 +208,7 @@ def _pass_filter(self,
206208
[ 0, start_index* self.ndescrpt],
207209
[-1, natoms[2+type_i]* self.ndescrpt] )
208210
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
209-
layer = self._filter_r(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
211+
layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
210212
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
211213
output.append(layer)
212214
start_index += natoms[2+type_i]
@@ -288,18 +290,18 @@ def _filter_r(self,
288290
for ii in range(1, len(outputs_size)):
289291
w = tf.get_variable('matrix_'+str(ii)+'_'+str(type_i),
290292
[outputs_size[ii - 1], outputs_size[ii]],
291-
global_tf_float_precision,
293+
self.filter_precision,
292294
tf.random_normal_initializer(stddev=stddev/np.sqrt(outputs_size[ii]+outputs_size[ii-1]), seed = seed),
293295
trainable = trainable)
294296
b = tf.get_variable('bias_'+str(ii)+'_'+str(type_i),
295297
[1, outputs_size[ii]],
296-
global_tf_float_precision,
298+
self.filter_precision,
297299
tf.random_normal_initializer(stddev=stddev, mean = bavg, seed = seed),
298300
trainable = trainable)
299301
if self.filter_resnet_dt :
300302
idt = tf.get_variable('idt_'+str(ii)+'_'+str(type_i),
301303
[1, outputs_size[ii]],
302-
global_tf_float_precision,
304+
self.filter_precision,
303305
tf.random_normal_initializer(stddev=0.001, mean = 1.0, seed = seed),
304306
trainable = trainable)
305307
if outputs_size[ii] == outputs_size[ii-1]:

0 commit comments

Comments
 (0)