Skip to content

Commit 32cf454

Browse files
author
Han Wang
committed
infer global polar
1 parent 7f8afdd commit 32cf454

File tree

4 files changed

+23
-5
lines changed

4 files changed

+23
-5
lines changed

deepmd/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .DeepPot import DeepPot
44
from .DeepDipole import DeepDipole
55
from .DeepPolar import DeepPolar
6+
from .DeepPolar import DeepGlobalPolar
67
from .DeepWFC import DeepWFC
78

89
set_mkl()

source/scripts/freeze.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def _make_node_names(model_type = None) :
4444
nodes = "o_dipole,descrpt_attr/rcut,descrpt_attr/ntypes,model_attr/tmap,model_attr/sel_type,model_attr/model_type"
4545
elif model_type == 'polar':
4646
nodes = "o_polar,descrpt_attr/rcut,descrpt_attr/ntypes,model_attr/tmap,model_attr/sel_type,model_attr/model_type"
47+
elif model_type == 'global_polar':
48+
nodes = "o_global_polar,descrpt_attr/rcut,descrpt_attr/ntypes,model_attr/tmap,model_attr/sel_type,model_attr/model_type"
4749
else:
4850
raise RuntimeError('unknow model type ' + model_type)
4951
return nodes

source/train/DeepEval.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def get_sel_type(self):
149149
def eval(self,
150150
coords,
151151
cells,
152-
atom_types) :
152+
atom_types,
153+
atomic = True) :
153154
# standarize the shape of inputs
154155
coords = np.array(coords)
155156
cells = np.array(cells)
@@ -183,10 +184,13 @@ def eval(self,
183184
tensor.append(v_out[0])
184185

185186
# reverse map of the outputs
186-
tensor = np.array(tensor)
187-
tensor = self.reverse_map(np.reshape(tensor, [nframes,-1,self.variable_dof]), sel_imap)
188-
189-
tensor = np.reshape(tensor, [nframes, len(sel_at), self.variable_dof])
187+
if atomic:
188+
tensor = np.array(tensor)
189+
tensor = self.reverse_map(np.reshape(tensor, [nframes,-1,self.variable_dof]), sel_imap)
190+
tensor = np.reshape(tensor, [nframes, len(sel_at), self.variable_dof])
191+
else:
192+
tensor = np.reshape(tensor, [nframes, self.variable_dof])
193+
190194
return tensor
191195

192196

source/train/DeepPolar.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,14 @@ def __init__(self,
99
model_file) :
1010
DeepTensor.__init__(self, model_file, 'polar', 9)
1111

12+
13+
class DeepGlobalPolar (DeepTensor) :
14+
def __init__(self,
15+
model_file) :
16+
DeepTensor.__init__(self, model_file, 'global_polar', 9)
17+
18+
def eval(self,
19+
coords,
20+
cells,
21+
atom_types) :
22+
return DeepTensor.eval(self, coords, cells, atom_types, atomic = False)

0 commit comments

Comments
 (0)