
Commit 854b89b

Author: Han Wang (committed)
Message: auto convert input to v1 compatibility
1 parent 48b60f9 · commit 854b89b

3 files changed: 144 additions, 1 deletion

source/train/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 configure_file("RunOptions.py.in" "${CMAKE_CURRENT_BINARY_DIR}/RunOptions.py" @ONLY)
 
-file(GLOB LIB_PY main.py common.py env.py Network.py Deep*.py Data.py DataSystem.py Model*.py Descrpt*.py Fitting.py Loss.py LearningRate.py Trainer.py TabInter.py ${CMAKE_CURRENT_BINARY_DIR}/RunOptions.py)
+file(GLOB LIB_PY main.py common.py env.py compat.py Network.py Deep*.py Data.py DataSystem.py Model*.py Descrpt*.py Fitting.py Loss.py LearningRate.py Trainer.py TabInter.py ${CMAKE_CURRENT_BINARY_DIR}/RunOptions.py)
 
 file(GLOB CLS_PY Local.py Slurm.py)
 

source/train/compat.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
import os, json, warnings
from deepmd.common import j_have, j_must_have, j_must_have_d

def convert_input_v0_v1(jdata, warning = True, dump = None) :
    """Convert a version 0.x.x input dict to the >=1.0.0 layout.
    If dump is not None, the converted input is also written to that file."""
    output = {}
    if 'with_distrib' in jdata:
        output['with_distrib'] = jdata['with_distrib']
    if jdata['use_smooth'] :
        output['model'] = _smth_model(jdata)
    else:
        output['model'] = _nonsmth_model(jdata)
    output['learning_rate'] = _learning_rate(jdata)
    output['loss'] = _loss(jdata)
    output['training'] = _training(jdata)
    if warning :
        _warnning_input_v0_v1(dump)
    if dump is not None:
        with open(dump, 'w') as fp:
            json.dump(output, fp, indent=4)
    return output

def _warnning_input_v0_v1(fname) :
    msg = 'It seems that you are using a deepmd-kit input of version 0.x.x, which is deprecated. The input has been converted to be compatible with >1.0.0'
    if fname is not None:
        msg += ', and written to the file ' + fname
    warnings.warn(msg)

def _nonsmth_model(jdata):
    model = {}
    model['descriptor'] = _nonsmth_descriptor(jdata)
    model['fitting_net'] = _fitting_net(jdata)
    return model

def _smth_model(jdata):
    model = {}
    model['descriptor'] = _smth_descriptor(jdata)
    model['fitting_net'] = _fitting_net(jdata)
    return model

def _nonsmth_descriptor(jdata) :
    output = {}
    seed = None
    if j_have (jdata, 'seed') :
        seed = jdata['seed']
    # descriptor of the non-smooth model: local frame
    descriptor = {}
    descriptor['type'] = 'loc_frame'
    descriptor['sel_a'] = jdata['sel_a']
    descriptor['sel_r'] = jdata['sel_r']
    descriptor['rcut'] = jdata['rcut']
    descriptor['axis_rule'] = jdata['axis_rule']
    return descriptor

def _smth_descriptor(jdata):
    # descriptor of the smooth model: se_a
    descriptor = {}
    seed = None
    if j_have (jdata, 'seed') :
        seed = jdata['seed']
    descriptor['type'] = 'se_a'
    descriptor['sel'] = jdata['sel_a']
    descriptor['rcut'] = jdata['rcut']
    if j_have(jdata, 'rcut_smth') :
        descriptor['rcut_r_smth'] = jdata['rcut_smth']
    else :
        descriptor['rcut_r_smth'] = descriptor['rcut']
    descriptor['neuron'] = j_must_have (jdata, 'filter_neuron')
    descriptor['axis_neuron'] = j_must_have_d (jdata, 'axis_neuron', ['n_axis_neuron'])
    descriptor['resnet_dt'] = False
    # honor the v0 filter_resnet_dt flag if present
    if j_have(jdata, 'filter_resnet_dt') :
        descriptor['resnet_dt'] = jdata['filter_resnet_dt']
    if seed is not None:
        descriptor['seed'] = seed
    return descriptor

def _fitting_net(jdata):
    fitting_net = {}
    seed = None
    if j_have (jdata, 'seed') :
        seed = jdata['seed']
    fitting_net['neuron'] = j_must_have_d (jdata, 'fitting_neuron', ['n_neuron'])
    fitting_net['resnet_dt'] = True
    if j_have(jdata, 'resnet_dt') :
        fitting_net['resnet_dt'] = jdata['resnet_dt']
    if j_have(jdata, 'fitting_resnet_dt') :
        fitting_net['resnet_dt'] = jdata['fitting_resnet_dt']
    if seed is not None:
        fitting_net['seed'] = seed
    return fitting_net

def _learning_rate(jdata):
    # learning rate
    learning_rate = {}
    learning_rate['type'] = 'exp'
    learning_rate['decay_steps'] = j_must_have(jdata, 'decay_steps')
    learning_rate['decay_rate'] = j_must_have(jdata, 'decay_rate')
    learning_rate['start_lr'] = j_must_have(jdata, 'start_lr')
    return learning_rate

def _loss(jdata):
    # loss
    loss = {}
    loss['start_pref_e'] = j_must_have (jdata, 'start_pref_e')
    loss['limit_pref_e'] = j_must_have (jdata, 'limit_pref_e')
    loss['start_pref_f'] = j_must_have (jdata, 'start_pref_f')
    loss['limit_pref_f'] = j_must_have (jdata, 'limit_pref_f')
    loss['start_pref_v'] = j_must_have (jdata, 'start_pref_v')
    loss['limit_pref_v'] = j_must_have (jdata, 'limit_pref_v')
    if j_have(jdata, 'start_pref_ae') :
        loss['start_pref_ae'] = jdata['start_pref_ae']
    if j_have(jdata, 'limit_pref_ae') :
        loss['limit_pref_ae'] = jdata['limit_pref_ae']
    return loss

def _training(jdata):
    # training
    training = {}
    seed = None
    if j_have (jdata, 'seed') :
        seed = jdata['seed']
    training['systems'] = jdata['systems']
    training['set_prefix'] = jdata['set_prefix']
    training['stop_batch'] = jdata['stop_batch']
    training['batch_size'] = jdata['batch_size']
    if seed is not None:
        training['seed'] = seed
    training['disp_file'] = "lcurve.out"
    if j_have (jdata, "disp_file") : training['disp_file'] = jdata["disp_file"]
    training['disp_freq'] = j_must_have (jdata, 'disp_freq')
    training['numb_test'] = j_must_have (jdata, 'numb_test')
    training['save_freq'] = j_must_have (jdata, 'save_freq')
    training['save_ckpt'] = j_must_have (jdata, 'save_ckpt')
    training['display_in_training'] = j_must_have (jdata, 'disp_training')
    training['timing_in_training'] = j_must_have (jdata, 'time_training')
    training['profiling'] = False
    if j_have (jdata, 'profiling') :
        training['profiling'] = jdata['profiling']
    if training['profiling'] :
        training['profiling_file'] = j_must_have (jdata, 'profiling_file')
    return training
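
As a quick illustration (not part of the commit), the converter can be exercised directly on an old-style input dict. The field names below are the v0 keys read by the functions above; the values, the ./data path and the checkpoint name are hypothetical.

# Hypothetical v0-style input; values are illustrative only.
from deepmd.compat import convert_input_v0_v1

v0_input = {
    "use_smooth": True, "seed": 1,
    "sel_a": [46, 92], "rcut": 6.0, "rcut_smth": 5.8,
    "filter_neuron": [25, 50, 100], "axis_neuron": 16,
    "fitting_neuron": [240, 240, 240],
    "start_lr": 1.0e-3, "decay_steps": 5000, "decay_rate": 0.95,
    "start_pref_e": 0.02, "limit_pref_e": 1.0,
    "start_pref_f": 1000.0, "limit_pref_f": 1.0,
    "start_pref_v": 0.0, "limit_pref_v": 0.0,
    "systems": ["./data"], "set_prefix": "set",
    "stop_batch": 1000000, "batch_size": 4,
    "disp_freq": 100, "numb_test": 10,
    "save_freq": 1000, "save_ckpt": "model.ckpt",
    "disp_training": True, "time_training": True,
}

# Returns the nested v1 layout and, because dump is given, also writes it
# to input_v1_compat.json (the same file name train.py uses below).
v1_input = convert_input_v0_v1(v0_input, warning=True, dump="input_v1_compat.json")
print(sorted(v1_input.keys()))   # ['learning_rate', 'loss', 'model', 'training']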

source/train/train.py

Lines changed: 5 additions & 0 deletions
@@ -7,6 +7,7 @@
 import argparse
 import json
 from deepmd.env import tf
+from deepmd.compat import convert_input_v0_v1
 
 lib_path = os.path.dirname(os.path.realpath(__file__)) + "/../lib/"
 sys.path.append (lib_path)
@@ -54,6 +55,10 @@ def train (args) :
     # load json database
     fp = open (args.INPUT, 'r')
     jdata = json.load (fp)
+    if not 'model' in jdata.keys():
+        jdata = convert_input_v0_v1(jdata,
+                                    warning = True,
+                                    dump = 'input_v1_compat.json')
     # run options
     with_distrib = False
     if 'with_distrib' in jdata:
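
For reference, the input_v1_compat.json written by this path follows the nested layout assembled in compat.py. A schematic of that layout as a Python literal; the values are illustrative placeholders, except 'exp' and "lcurve.out", which are the defaults hard-coded above:

# Schematic v1 layout produced by the converter (illustrative values only).
v1_layout = {
    "model": {
        "descriptor": {"type": "se_a", "sel": [46, 92], "rcut": 6.0},   # "loc_frame" for the non-smooth model
        "fitting_net": {"neuron": [240, 240, 240], "resnet_dt": True},
    },
    "learning_rate": {"type": "exp", "start_lr": 1.0e-3,
                      "decay_steps": 5000, "decay_rate": 0.95},
    "loss": {"start_pref_e": 0.02, "limit_pref_e": 1.0,
             "start_pref_f": 1000.0, "limit_pref_f": 1.0,
             "start_pref_v": 0.0, "limit_pref_v": 0.0},
    "training": {"systems": ["./data"], "set_prefix": "set",
                 "stop_batch": 1000000, "batch_size": 4,
                 "disp_file": "lcurve.out", "disp_freq": 100,
                 "numb_test": 10, "save_freq": 1000, "save_ckpt": "model.ckpt",
                 "display_in_training": True, "timing_in_training": True,
                 "profiling": False},
}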
