Skip to content

Commit 8f09de9

Browse files
author
Han Wang
committed
support no-pbc-training
1 parent 25f1ea0 commit 8f09de9

File tree

3 files changed

+40
-18
lines changed

3 files changed

+40
-18
lines changed

source/train/Data.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def __init__ (self,
2525
self.type_map = self._load_type_map(sys_path)
2626
if self.type_map is not None:
2727
assert(len(self.type_map) >= max(self.atom_type)+1)
28+
# check pbc
29+
self.pbc = self._check_pbc(sys_path)
2830
# enforce type_map if necessary
2931
if type_map is not None and self.type_map is not None:
3032
atom_type_ = [type_map.index(self.type_map[ii]) for ii in self.atom_type]
@@ -348,6 +350,12 @@ def _load_type_map(self, sys_path) :
348350
else :
349351
return None
350352

353+
def _check_pbc(self, sys_path):
354+
pbc = True
355+
if os.path.isfile(os.path.join(sys_path, 'nopbc')) :
356+
pbc = False
357+
return pbc
358+
351359

352360
class DataSets (object):
353361
def __init__ (self,

source/train/DataSystem.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,20 +95,24 @@ def _load_test(self, ntests = -1):
9595
for nn in test_system_data:
9696
self.test_data[nn].append(test_system_data[nn])
9797

98+
9899
def _make_default_mesh(self):
99100
self.default_mesh = []
100101
cell_size = np.max (self.rcut)
101102
for ii in range(self.nsystems) :
102-
test_system_data = self.data_systems[ii].get_batch(self.batch_size[ii])
103-
self.data_systems[ii].reset_get_batch()
104-
# test_system_data = self.data_systems[ii].get_test()
105-
avg_box = np.average (test_system_data["box"], axis = 0)
106-
avg_box = np.reshape (avg_box, [3,3])
107-
ncell = (np.linalg.norm(avg_box, axis=1)/ cell_size).astype(np.int32)
108-
ncell[ncell < 2] = 2
109-
default_mesh = np.zeros (6, dtype = np.int32)
110-
default_mesh[3:6] = ncell
111-
self.default_mesh.append(default_mesh)
103+
if self.data_systems[ii].pbc :
104+
test_system_data = self.data_systems[ii].get_batch(self.batch_size[ii])
105+
self.data_systems[ii].reset_get_batch()
106+
# test_system_data = self.data_systems[ii].get_test()
107+
avg_box = np.average (test_system_data["box"], axis = 0)
108+
avg_box = np.reshape (avg_box, [3,3])
109+
ncell = (np.linalg.norm(avg_box, axis=1)/ cell_size).astype(np.int32)
110+
ncell[ncell < 2] = 2
111+
default_mesh = np.zeros (6, dtype = np.int32)
112+
default_mesh[3:6] = ncell
113+
self.default_mesh.append(default_mesh)
114+
else:
115+
self.default_mesh.append(np.array([], dtype = np.int32))
112116

113117

114118
def compute_energy_shift(self, rcond = 1e-3, key = 'energy') :

source/train/DeepPot.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def __init__(self,
5353
self.modifier_type = self.sess.run(t_modifier_type).decode('UTF-8')
5454
except ValueError:
5555
self.modifier_type = None
56+
except KeyError:
57+
self.modifier_type = None
5658
if self.modifier_type == 'dipole_charge':
5759
t_mdl_name = self.graph.get_tensor_by_name('load/modifier_attr/mdl_name:0')
5860
t_mdl_charge_map = self.graph.get_tensor_by_name('load/modifier_attr/mdl_charge_map:0')
@@ -108,9 +110,18 @@ def eval_inner(self,
108110
aparam = None,
109111
atomic = False) :
110112
# standarize the shape of inputs
111-
coords = np.array(coords)
112-
cells = np.array(cells)
113-
atom_types = np.array(atom_types, dtype = int)
113+
atom_types = np.array(atom_types, dtype = int).reshape([-1])
114+
natoms = atom_types.size
115+
coords = np.reshape(np.array(coords), [-1, natoms * 3])
116+
nframes = coords.shape[0]
117+
if cells is None:
118+
pbc = False
119+
# make cells to work around the requirement of pbc
120+
cells = np.tile(np.eye(3), [nframes, 1]).reshape([nframes, 9])
121+
else:
122+
pbc = True
123+
cells = np.array(cells).reshape([nframes, 9])
124+
114125
if self.has_fparam :
115126
assert(fparam is not None)
116127
fparam = np.array(fparam)
@@ -119,10 +130,6 @@ def eval_inner(self,
119130
aparam = np.array(aparam)
120131

121132
# reshape the inputs
122-
cells = np.reshape(cells, [-1, 9])
123-
nframes = cells.shape[0]
124-
coords = np.reshape(coords, [nframes, -1])
125-
natoms = coords.shape[1] // 3
126133
if self.has_fparam :
127134
fdim = self.get_dim_fparam()
128135
if fparam.size == nframes * fdim :
@@ -167,7 +174,10 @@ def eval_inner(self,
167174
for ii in range(nframes) :
168175
feed_dict_test[self.t_coord] = np.reshape(coords[ii:ii+1, :], [-1])
169176
feed_dict_test[self.t_box ] = np.reshape(cells [ii:ii+1, :], [-1])
170-
feed_dict_test[self.t_mesh ] = make_default_mesh(cells[ii:ii+1, :])
177+
if pbc:
178+
feed_dict_test[self.t_mesh ] = make_default_mesh(cells[ii:ii+1, :])
179+
else:
180+
feed_dict_test[self.t_mesh ] = np.array([], dtype = np.int32)
171181
if self.has_fparam:
172182
feed_dict_test[self.t_fparam] = np.reshape(fparam[ii:ii+1, :], [-1])
173183
if self.has_aparam:

0 commit comments

Comments
 (0)