@@ -53,6 +53,8 @@ def __init__(self,
5353 self .modifier_type = self .sess .run (t_modifier_type ).decode ('UTF-8' )
5454 except ValueError :
5555 self .modifier_type = None
56+ except KeyError :
57+ self .modifier_type = None
5658 if self .modifier_type == 'dipole_charge' :
5759 t_mdl_name = self .graph .get_tensor_by_name ('load/modifier_attr/mdl_name:0' )
5860 t_mdl_charge_map = self .graph .get_tensor_by_name ('load/modifier_attr/mdl_charge_map:0' )
@@ -108,9 +110,18 @@ def eval_inner(self,
108110 aparam = None ,
109111 atomic = False ) :
110112 # standarize the shape of inputs
111- coords = np .array (coords )
112- cells = np .array (cells )
113- atom_types = np .array (atom_types , dtype = int )
113+ atom_types = np .array (atom_types , dtype = int ).reshape ([- 1 ])
114+ natoms = atom_types .size
115+ coords = np .reshape (np .array (coords ), [- 1 , natoms * 3 ])
116+ nframes = coords .shape [0 ]
117+ if cells is None :
118+ pbc = False
119+ # make cells to work around the requirement of pbc
120+ cells = np .tile (np .eye (3 ), [nframes , 1 ]).reshape ([nframes , 9 ])
121+ else :
122+ pbc = True
123+ cells = np .array (cells ).reshape ([nframes , 9 ])
124+
114125 if self .has_fparam :
115126 assert (fparam is not None )
116127 fparam = np .array (fparam )
@@ -119,10 +130,6 @@ def eval_inner(self,
119130 aparam = np .array (aparam )
120131
121132 # reshape the inputs
122- cells = np .reshape (cells , [- 1 , 9 ])
123- nframes = cells .shape [0 ]
124- coords = np .reshape (coords , [nframes , - 1 ])
125- natoms = coords .shape [1 ] // 3
126133 if self .has_fparam :
127134 fdim = self .get_dim_fparam ()
128135 if fparam .size == nframes * fdim :
@@ -167,7 +174,10 @@ def eval_inner(self,
167174 for ii in range (nframes ) :
168175 feed_dict_test [self .t_coord ] = np .reshape (coords [ii :ii + 1 , :], [- 1 ])
169176 feed_dict_test [self .t_box ] = np .reshape (cells [ii :ii + 1 , :], [- 1 ])
170- feed_dict_test [self .t_mesh ] = make_default_mesh (cells [ii :ii + 1 , :])
177+ if pbc :
178+ feed_dict_test [self .t_mesh ] = make_default_mesh (cells [ii :ii + 1 , :])
179+ else :
180+ feed_dict_test [self .t_mesh ] = np .array ([], dtype = np .int32 )
171181 if self .has_fparam :
172182 feed_dict_test [self .t_fparam ] = np .reshape (fparam [ii :ii + 1 , :], [- 1 ])
173183 if self .has_aparam :
0 commit comments