@@ -839,6 +839,44 @@ def shuffle(self):
839
839
self .data [ii ] = self .data [ii ][idx ]
840
840
return idx
841
841
842
+ def predict (self , dp ):
843
+ """
844
+ Predict energies and forces by deepmd-kit.
845
+
846
+ Parameters
847
+ ----------
848
+ dp : deepmd.DeepPot or str
849
+ The deepmd-kit potential class or the filename of the model.
850
+
851
+ Returns
852
+ -------
853
+ labeled_sys LabeledSystem
854
+ The labeled system.
855
+ """
856
+ import deepmd .DeepPot as DeepPot
857
+ if not isinstance (dp , DeepPot ):
858
+ dp = DeepPot (dp )
859
+ type_map = dp .get_type_map ()
860
+ ori_sys = self .copy ()
861
+ ori_sys .sort_atom_names (type_map = type_map )
862
+ atype = ori_sys ['atom_types' ]
863
+
864
+ labeled_sys = LabeledSystem ()
865
+
866
+ for ss in self :
867
+ coord = ss ['coords' ].reshape ((- 1 ,1 ))
868
+ if not ss .nopbc :
869
+ cell = ss ['cells' ].reshape ((- 1 ,1 ))
870
+ else :
871
+ cell = None
872
+ e , f , v = dp .eval (coord , cell , atype )
873
+ data = ss .data
874
+ data ['energies' ] = e .reshape ((1 , 1 ))
875
+ data ['forces' ] = f .reshape ((1 , - 1 , 3 ))
876
+ data ['virials' ] = v .reshape ((1 , 3 , 3 ))
877
+ this_sys = LabeledSystem .from_dict ({'data' : data })
878
+ labeled_sys .append (this_sys )
879
+ return labeled_sys
842
880
843
881
def get_cell_perturb_matrix (cell_pert_fraction ):
844
882
if cell_pert_fraction < 0 :
0 commit comments