Skip to content

Commit 746a799

Browse files
author
Han Wang
committed
implement test with dipole
1 parent 2a82848 commit 746a799

File tree

4 files changed

+54
-4
lines changed

4 files changed

+54
-4
lines changed

deepmd/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .env import set_mkl
2-
from .DeepEval import DeepEval
3-
from .DeepPot import DeepPot
4-
from .DeepPolar import DeepPolar
5-
from .DeepWFC import DeepWFC
2+
from .DeepEval import DeepEval
3+
from .DeepPot import DeepPot
4+
from .DeepDipole import DeepDipole
5+
from .DeepPolar import DeepPolar
6+
from .DeepWFC import DeepWFC
67

78
set_mkl()
89

source/scripts/freeze.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def _make_node_names(model_type = None) :
4040
nodes = "o_energy,o_force,o_virial,o_atom_energy,o_atom_virial,descrpt_attr/rcut,descrpt_attr/ntypes,fitting_attr/dfparam,fitting_attr/daparam,model_attr/tmap,model_attr/model_type"
4141
elif model_type == 'wfc':
4242
nodes = "o_wfc,descrpt_attr/rcut,descrpt_attr/ntypes,model_attr/tmap,model_attr/sel_type,model_attr/model_type"
43+
elif model_type == 'dipole':
44+
nodes = "o_dipole,descrpt_attr/rcut,descrpt_attr/ntypes,model_attr/tmap,model_attr/sel_type,model_attr/model_type"
4345
elif model_type == 'polar':
4446
nodes = "o_polar,descrpt_attr/rcut,descrpt_attr/ntypes,model_attr/tmap,model_attr/sel_type,model_attr/model_type"
4547
else:

source/train/DeepDipole.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/usr/bin/env python3
2+
3+
import os,sys
4+
import numpy as np
5+
from deepmd.DeepEval import DeepTensor
6+
7+
class DeepDipole (DeepTensor) :
8+
def __init__(self,
9+
model_file) :
10+
DeepTensor.__init__(self, model_file, 'dipole', 3)
11+

source/train/test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from deepmd.Data import DeepmdData
1111
from deepmd import DeepEval
1212
from deepmd import DeepPot
13+
from deepmd import DeepDipole
1314
from deepmd import DeepPolar
1415
from deepmd import DeepWFC
1516
from tensorflow.python.framework import ops
@@ -18,6 +19,8 @@ def test (args):
1819
de = DeepEval(args.model)
1920
if de.model_type == 'ener':
2021
test_ener(args)
22+
elif de.model_type == 'dipole':
23+
test_dipole(args)
2124
elif de.model_type == 'polar':
2225
test_polar(args)
2326
elif de.model_type == 'wfc':
@@ -154,3 +157,36 @@ def test_polar (args) :
154157
axis = 1)
155158
np.savetxt(detail_file+".out", pe,
156159
header = 'data_pxx data_pxy data_pxz data_pyx data_pyy data_pyz data_pzx data_pzy data_pzz pred_pxx pred_pxy pred_pxz pred_pyx pred_pyy pred_pyz pred_pzx pred_pzy pred_pzz')
160+
161+
162+
def test_dipole (args) :
163+
if args.rand_seed is not None :
164+
np.random.seed(args.rand_seed % (2**32))
165+
166+
dp = DeepDipole(args.model)
167+
data = DeepmdData(args.system, args.set_prefix, shuffle_test = args.shuffle_test)
168+
data.add('dipole', 3, atomic=True, must=True, high_prec=False, type_sel = dp.get_sel_type())
169+
test_data = data.get_test ()
170+
numb_test = args.numb_test
171+
natoms = len(test_data["type"][0])
172+
nframes = test_data["box"].shape[0]
173+
numb_test = min(nframes, numb_test)
174+
175+
coord = test_data["coord"][:numb_test].reshape([numb_test, -1])
176+
box = test_data["box"][:numb_test]
177+
atype = test_data["type"][0]
178+
dipole = dp.eval(coord, box, atype)
179+
180+
dipole = dipole.reshape([numb_test,-1])
181+
l2f = (l2err (dipole - test_data["dipole"] [:numb_test]))
182+
183+
print ("# number of test data : %d " % numb_test)
184+
print ("Dipole L2err : %e eV/A" % l2f)
185+
186+
detail_file = args.detail_file
187+
if detail_file is not None :
188+
pe = np.concatenate((np.reshape(test_data["dipole"][:numb_test], [-1,3]),
189+
np.reshape(dipole, [-1,3])),
190+
axis = 1)
191+
np.savetxt(detail_file+".out", pe,
192+
header = 'data_x data_y data_z pred_x pred_y pred_z')

0 commit comments

Comments
 (0)