@@ -661,7 +661,7 @@ def predict(self, dp):
661
661
662
662
Returns
663
663
-------
664
- labeled_sys LabeledSystem
664
+ labeled_sys : LabeledSystem
665
665
The labeled system.
666
666
"""
667
667
try :
@@ -679,19 +679,33 @@ def predict(self, dp):
679
679
680
680
labeled_sys = LabeledSystem ()
681
681
682
- for ss in self :
683
- coord = ss ['coords' ].reshape ((- 1 ,1 ))
684
- if not ss .nopbc :
685
- cell = ss ['cells' ].reshape ((- 1 ,1 ))
682
+ if 'auto_batch_size' not in DeepPot .__init__ .__code__ .co_varnames :
683
+ for ss in self :
684
+ coord = ss ['coords' ].reshape ((1 , ss .get_natoms ()* 3 ))
685
+ if not ss .nopbc :
686
+ cell = ss ['cells' ].reshape ((1 , 9 ))
687
+ else :
688
+ cell = None
689
+ e , f , v = dp .eval (coord , cell , atype )
690
+ data = ss .data
691
+ data ['energies' ] = e .reshape ((1 , 1 ))
692
+ data ['forces' ] = f .reshape ((1 , ss .get_natoms (), 3 ))
693
+ data ['virials' ] = v .reshape ((1 , 3 , 3 ))
694
+ this_sys = LabeledSystem .from_dict ({'data' : data })
695
+ labeled_sys .append (this_sys )
696
+ else :
697
+ # since v2.0.2, auto batch size is supported
698
+ coord = self .data ['coords' ].reshape ((self .get_nframes (), self .get_natoms ()* 3 ))
699
+ if not self .nopbc :
700
+ cell = self .data ['cells' ].reshape ((self .get_nframes (), 9 ))
686
701
else :
687
702
cell = None
688
703
e , f , v = dp .eval (coord , cell , atype )
689
- data = ss .data
690
- data ['energies' ] = e .reshape ((1 , 1 ))
691
- data ['forces' ] = f .reshape ((1 , - 1 , 3 ))
692
- data ['virials' ] = v .reshape ((1 , 3 , 3 ))
693
- this_sys = LabeledSystem .from_dict ({'data' : data })
694
- labeled_sys .append (this_sys )
704
+ data = self .data .copy ()
705
+ data ['energies' ] = e .reshape ((self .get_nframes (), 1 ))
706
+ data ['forces' ] = f .reshape ((self .get_nframes (), self .get_natoms (), 3 ))
707
+ data ['virials' ] = v .reshape ((self .get_nframes (), 3 , 3 ))
708
+ labeled_sys = LabeledSystem .from_dict ({'data' : data })
695
709
return labeled_sys
696
710
697
711
def pick_atom_idx (self , idx , nopbc = None ):
0 commit comments