1212from deepmd import DeepPot
1313from deepmd import DeepDipole
1414from deepmd import DeepPolar
15+ from deepmd import DeepGlobalPolar
1516from deepmd import DeepWFC
1617from tensorflow .python .framework import ops
1718
@@ -28,6 +29,8 @@ def test (args):
2829 dp = DeepDipole (args .model )
2930 elif de .model_type == 'polar' :
3031 dp = DeepPolar (args .model )
32+ elif de .model_type == 'global_polar' :
33+ dp = DeepGlobalPolar (args .model )
3134 elif de .model_type == 'wfc' :
3235 dp = DeepWFC (args .model )
3336 else :
@@ -41,7 +44,9 @@ def test (args):
4144 elif de .model_type == 'dipole' :
4245 err , siz = test_dipole (dp , args )
4346 elif de .model_type == 'polar' :
44- err , siz = test_polar (dp , args )
47+ err , siz = test_polar (dp , args , global_polar = False )
48+ elif de .model_type == 'global_polar' :
49+ err , siz = test_polar (dp , args , global_polar = True )
4550 elif de .model_type == 'wfc' :
4651 err , siz = test_wfc (dp , args )
4752 else :
@@ -61,6 +66,8 @@ def test (args):
6166 print_dipole_sys_avg (avg_err )
6267 elif de .model_type == 'polar' :
6368 print_polar_sys_avg (avg_err )
69+ elif de .model_type == 'global_polar' :
70+ print_polar_sys_avg (avg_err )
6471 elif de .model_type == 'wfc' :
6572 print_wfc_sys_avg (avg_err )
6673 else :
@@ -223,12 +230,15 @@ def print_wfc_sys_avg(avg):
223230 print ("WFC L2err : %e eV/A" % avg [0 ])
224231
225232
226- def test_polar (dp , args ) :
233+ def test_polar (dp , args , global_polar = False ) :
227234 if args .rand_seed is not None :
228235 np .random .seed (args .rand_seed % (2 ** 32 ))
229236
230237 data = DeepmdData (args .system , args .set_prefix , shuffle_test = args .shuffle_test )
231- data .add ('polarizability' , 9 , atomic = True , must = True , high_prec = False , type_sel = dp .get_sel_type ())
238+ if not global_polar :
239+ data .add ('polarizability' , 9 , atomic = True , must = True , high_prec = False , type_sel = dp .get_sel_type ())
240+ else :
241+ data .add ('polarizability' , 9 , atomic = False , must = True , high_prec = False , type_sel = dp .get_sel_type ())
232242 test_data = data .get_test ()
233243 numb_test = args .numb_test
234244 natoms = len (test_data ["type" ][0 ])
@@ -239,12 +249,21 @@ def test_polar (dp, args) :
239249 box = test_data ["box" ][:numb_test ]
240250 atype = test_data ["type" ][0 ]
241251 polar = dp .eval (coord , box , atype )
252+ sel_type = dp .get_sel_type ()
253+ sel_natoms = 0
254+ for ii in sel_type :
255+ sel_natoms += sum (atype == ii )
242256
243257 polar = polar .reshape ([numb_test ,- 1 ])
244258 l2f = (l2err (polar - test_data ["polarizability" ] [:numb_test ]))
259+ l2fs = l2f / np .sqrt (sel_natoms )
260+ l2fa = l2f / sel_natoms
245261
246262 print ("# number of test data : %d " % numb_test )
247- print ("Polarizability L2err : %e eV/A" % l2f )
263+ print ("Polarizability L2err : %e eV/A" % l2f )
264+ if global_polar :
265+ print ("Polarizability L2err/sqrtN : %e eV/A" % l2fs )
266+ print ("Polarizability L2err/N : %e eV/A" % l2fa )
248267
249268 detail_file = args .detail_file
250269 if detail_file is not None :
0 commit comments