1010from deepmd .Data import DeepmdData
1111from deepmd import DeepEval
1212from deepmd import DeepPot
13+ from deepmd import DeepDipole
1314from deepmd import DeepPolar
1415from deepmd import DeepWFC
1516from tensorflow .python .framework import ops
@@ -18,6 +19,8 @@ def test (args):
1819 de = DeepEval (args .model )
1920 if de .model_type == 'ener' :
2021 test_ener (args )
22+ elif de .model_type == 'dipole' :
23+ test_dipole (args )
2124 elif de .model_type == 'polar' :
2225 test_polar (args )
2326 elif de .model_type == 'wfc' :
@@ -154,3 +157,36 @@ def test_polar (args) :
154157 axis = 1 )
155158 np .savetxt (detail_file + ".out" , pe ,
156159 header = 'data_pxx data_pxy data_pxz data_pyx data_pyy data_pyz data_pzx data_pzy data_pzz pred_pxx pred_pxy pred_pxz pred_pyx pred_pyy pred_pyz pred_pzx pred_pzy pred_pzz' )
160+
161+
162+ def test_dipole (args ) :
163+ if args .rand_seed is not None :
164+ np .random .seed (args .rand_seed % (2 ** 32 ))
165+
166+ dp = DeepDipole (args .model )
167+ data = DeepmdData (args .system , args .set_prefix , shuffle_test = args .shuffle_test )
168+ data .add ('dipole' , 3 , atomic = True , must = True , high_prec = False , type_sel = dp .get_sel_type ())
169+ test_data = data .get_test ()
170+ numb_test = args .numb_test
171+ natoms = len (test_data ["type" ][0 ])
172+ nframes = test_data ["box" ].shape [0 ]
173+ numb_test = min (nframes , numb_test )
174+
175+ coord = test_data ["coord" ][:numb_test ].reshape ([numb_test , - 1 ])
176+ box = test_data ["box" ][:numb_test ]
177+ atype = test_data ["type" ][0 ]
178+ dipole = dp .eval (coord , box , atype )
179+
180+ dipole = dipole .reshape ([numb_test ,- 1 ])
181+ l2f = (l2err (dipole - test_data ["dipole" ] [:numb_test ]))
182+
183+ print ("# number of test data : %d " % numb_test )
184+ print ("Dipole L2err : %e eV/A" % l2f )
185+
186+ detail_file = args .detail_file
187+ if detail_file is not None :
188+ pe = np .concatenate ((np .reshape (test_data ["dipole" ][:numb_test ], [- 1 ,3 ]),
189+ np .reshape (dipole , [- 1 ,3 ])),
190+ axis = 1 )
191+ np .savetxt (detail_file + ".out" , pe ,
192+ header = 'data_x data_y data_z pred_x pred_y pred_z' )
0 commit comments