@@ -36,13 +36,19 @@ def test_ener (args) :
3636
3737 dp = DeepPot (args .model )
3838 data = DeepmdData (args .system , args .set_prefix , shuffle_test = args .shuffle_test , type_map = dp .get_type_map ())
39+ data .add ('energy' , 1 , atomic = False , must = False , high_prec = True )
40+ data .add ('force' , 3 , atomic = True , must = False , high_prec = False )
41+ data .add ('virial' , 9 , atomic = False , must = False , high_prec = False )
42+ if dp .get_dim_fparam () > 0 :
43+ data .add ('fparam' , dp .get_dim_fparam (), atomic = False , must = True , high_prec = False )
44+ if dp .get_dim_aparam () > 0 :
45+ data .add ('aparam' , dp .get_dim_aparam (), atomic = True , must = True , high_prec = False )
3946
4047 test_data = data .get_test ()
41- numb_test = args .numb_test
4248 natoms = len (test_data ["type" ][0 ])
4349 nframes = test_data ["box" ].shape [0 ]
50+ numb_test = args .numb_test
4451 numb_test = min (nframes , numb_test )
45-
4652 coord = test_data ["coord" ][:numb_test ].reshape ([numb_test , - 1 ])
4753 box = test_data ["box" ][:numb_test ]
4854 atype = test_data ["type" ][0 ]
@@ -54,6 +60,7 @@ def test_ener (args) :
5460 aparam = test_data ["aparam" ][:numb_test ]
5561 else :
5662 aparam = None
63+
5764 energy , force , virial , ae , av = dp .eval (coord , box , atype , fparam = fparam , aparam = aparam , atomic = True )
5865 energy = energy .reshape ([numb_test ,1 ])
5966 force = force .reshape ([numb_test ,- 1 ])
0 commit comments