Skip to content

Commit 7e6594a

Browse files
author
Han Wang
committed
simplifies the implementation of Model (WFC and Polar) and WFCLoss
1 parent 8eca960 commit 7e6594a

File tree

8 files changed

+141
-169
lines changed

8 files changed

+141
-169
lines changed

examples/water/train/polar.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"_comment": " model parameters",
44
"model":{
55
"type_map": ["O", "H"],
6+
"data_stat_nbatch": 1,
67
"descriptor": {
78
"type": "loc_frame",
89
"sel_a": [16, 32],

examples/water/train/polar_se_a.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"_comment": " model parameters",
44
"model":{
55
"type_map": ["O", "H"],
6+
"data_stat_nbatch": 1,
67
"descriptor" :{
78
"type": "se_a",
89
"sel": [46, 92],

examples/water/train/wannier.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"_comment": " model parameters",
44
"model":{
55
"type_map": ["O", "H"],
6+
"data_stat_nbatch": 1,
67
"descriptor": {
78
"type": "loc_frame",
89
"sel_a": [16, 32],

source/tests/test_wfc.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import dpdata,os,sys,json,unittest
2+
import numpy as np
3+
from deepmd.env import tf
4+
from common import Data,gen_data
5+
6+
from deepmd.RunOptions import RunOptions
7+
from deepmd.DataSystem import DataSystem
8+
from deepmd.DescrptLocFrame import DescrptLocFrame
9+
from deepmd.Fitting import WFCFitting
10+
from deepmd.Model import WFCModel
11+
from deepmd.common import j_must_have, j_must_have_d, j_have
12+
13+
global_ener_float_precision = tf.float64
14+
global_tf_float_precision = tf.float64
15+
global_np_float_precision = np.float64
16+
17+
class TestModel(unittest.TestCase):
18+
def setUp(self) :
19+
gen_data()
20+
21+
def test_model(self):
22+
jfile = 'wfc.json'
23+
with open(jfile) as fp:
24+
jdata = json.load (fp)
25+
run_opt = RunOptions(None)
26+
systems = j_must_have(jdata, 'systems')
27+
set_pfx = j_must_have(jdata, 'set_prefix')
28+
batch_size = j_must_have(jdata, 'batch_size')
29+
test_size = j_must_have(jdata, 'numb_test')
30+
batch_size = 1
31+
test_size = 1
32+
stop_batch = j_must_have(jdata, 'stop_batch')
33+
rcut = j_must_have (jdata['model']['descriptor'], 'rcut')
34+
35+
data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt = None)
36+
37+
test_data = data.get_test ()
38+
numb_test = 1
39+
40+
descrpt = DescrptLocFrame(jdata['model']['descriptor'])
41+
fitting = WFCFitting(jdata['model']['fitting_net'], descrpt)
42+
model = WFCModel(jdata['model'], descrpt, fitting)
43+
44+
input_data = {'coord' : [test_data['coord']],
45+
'box': [test_data['box']],
46+
'type': [test_data['type']],
47+
'natoms_vec' : [test_data['natoms_vec']],
48+
'default_mesh' : [test_data['default_mesh']],
49+
'fparam': [test_data['fparam']],
50+
}
51+
model._compute_dstats(input_data)
52+
53+
t_prop_c = tf.placeholder(tf.float32, [5], name='t_prop_c')
54+
t_energy = tf.placeholder(global_ener_float_precision, [None], name='t_energy')
55+
t_force = tf.placeholder(global_tf_float_precision, [None], name='t_force')
56+
t_virial = tf.placeholder(global_tf_float_precision, [None], name='t_virial')
57+
t_atom_ener = tf.placeholder(global_tf_float_precision, [None], name='t_atom_ener')
58+
t_coord = tf.placeholder(global_tf_float_precision, [None], name='i_coord')
59+
t_type = tf.placeholder(tf.int32, [None], name='i_type')
60+
t_natoms = tf.placeholder(tf.int32, [model.ntypes+2], name='i_natoms')
61+
t_box = tf.placeholder(global_tf_float_precision, [None, 9], name='i_box')
62+
t_mesh = tf.placeholder(tf.int32, [None], name='i_mesh')
63+
is_training = tf.placeholder(tf.bool)
64+
t_fparam = None
65+
66+
model_pred \
67+
= model.build (t_coord,
68+
t_type,
69+
t_natoms,
70+
t_box,
71+
t_mesh,
72+
t_fparam,
73+
suffix = "wfc",
74+
reuse = False)
75+
wfc = model_pred['wfc']
76+
77+
feed_dict_test = {t_prop_c: test_data['prop_c'],
78+
t_coord: np.reshape(test_data['coord'] [:numb_test, :], [-1]),
79+
t_box: test_data['box'] [:numb_test, :],
80+
t_type: np.reshape(test_data['type'] [:numb_test, :], [-1]),
81+
t_natoms: test_data['natoms_vec'],
82+
t_mesh: test_data['default_mesh'],
83+
is_training: False}
84+
85+
sess = tf.Session()
86+
sess.run(tf.global_variables_initializer())
87+
[p] = sess.run([wfc], feed_dict = feed_dict_test)
88+
89+
p = p.reshape([-1])
90+
refp = [-9.105016838228578990e-01,7.196284362034099935e-01,-9.548516928185298014e-02,2.764615027095288724e+00,2.661319598995644520e-01,7.579512949131941846e-02,-2.107409067376114997e+00,-1.299080016614967414e-01,-5.962778584850070285e-01,2.913899917663253514e-01,-1.226917174638697094e+00,1.829523069930876655e+00,1.015704024959750873e+00,-1.792333611099589386e-01,5.032898080485321834e-01,1.808561721292949453e-01,2.468863482075112081e+00,-2.566442546384765100e-01,-1.467453783795173994e-01,-1.822963931552128658e+00,5.843600156865462747e-01,-1.493875280832117403e+00,1.693322352814763398e-01,-1.877325443995481624e+00]
91+
92+
places = 6
93+
for ii in range(p.size) :
94+
self.assertAlmostEqual(p[ii], refp[ii], places = places)
95+
96+
97+

source/train/Fitting.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ def get_sel_type(self):
211211
def get_wfc_numb(self):
212212
return self.wfc_numb
213213

214+
def get_out_size(self):
215+
return self.wfc_numb * 3
216+
214217
def build (self,
215218
input_d,
216219
rot_mat,
@@ -283,6 +286,9 @@ def __init__ (self, jdata, descrpt) :
283286
def get_sel_type(self):
284287
return self.sel_type
285288

289+
def get_out_size(self):
290+
return 9
291+
286292
def build (self,
287293
input_d,
288294
rot_mat,
@@ -360,6 +366,9 @@ def __init__ (self, jdata, descrpt) :
360366
def get_sel_type(self):
361367
return self.sel_type
362368

369+
def get_out_size(self):
370+
return 9
371+
363372
def build (self,
364373
input_d,
365374
rot_mat,

source/train/Loss.py

Lines changed: 2 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -166,57 +166,8 @@ def print_on_training(self,
166166
if self.has_pf:
167167
print_str += prop_fmt % (np.sqrt(error_pf_test) / natoms[0], np.sqrt(error_pf_train) / natoms[0])
168168

169-
return print_str
170-
171-
172-
173-
class WFCLoss () :
174-
def __init__ (self, jdata, **kwarg) :
175-
model = kwarg['model']
176-
# data required
177-
add_data_requirement('wfc',
178-
model.get_wfc_numb() * 3,
179-
atomic=True,
180-
must=True,
181-
high_prec=False,
182-
type_sel = model.get_sel_type())
183-
184-
def build (self,
185-
learning_rate,
186-
natoms,
187-
model_dict,
188-
label_dict,
189-
suffix):
190-
wfc_hat = label_dict['wfc']
191-
wfc = model_dict['wfc']
192-
l2_loss = tf.reduce_mean( tf.square(wfc - wfc_hat), name='l2_'+suffix)
193-
self.l2_l = l2_loss
194-
more_loss = {}
195-
196-
return l2_loss, more_loss
197-
198-
def print_header(self) :
199-
prop_fmt = ' %9s %9s'
200-
print_str = ''
201-
print_str += prop_fmt % ('l2_tst', 'l2_trn')
202-
return print_str
203-
204-
def print_on_training(self,
205-
sess,
206-
natoms,
207-
feed_dict_test,
208-
feed_dict_batch) :
209-
error_test\
210-
= sess.run([self.l2_l], \
211-
feed_dict=feed_dict_test)
212-
error_train\
213-
= sess.run([self.l2_l], \
214-
feed_dict=feed_dict_batch)
215-
print_str = ""
216-
prop_fmt = " %9.2e %9.2e"
217-
print_str += prop_fmt % (np.sqrt(error_test), np.sqrt(error_train))
169+
return print_str
218170

219-
return print_str
220171

221172

222173
class TensorLoss () :
@@ -235,7 +186,7 @@ def __init__ (self, jdata, **kwarg) :
235186
atomic=True,
236187
must=True,
237188
high_prec=False,
238-
type_sel = model.get_sel_type())
189+
type_sel = type_sel)
239190

240191
def build (self,
241192
learning_rate,

0 commit comments

Comments
 (0)