1818def test (args ):
1919 de = DeepEval (args .model )
2020 all_sys = expand_sys_str (args .system )
21+ err_coll = []
22+ siz_coll = []
2123 for ii in all_sys :
2224 args .system = ii
2325 print ("# ---------------output of dp test--------------- " )
2426 print ("# testing system : " + ii )
2527 if de .model_type == 'ener' :
26- test_ener (args )
28+ err , siz = test_ener (args )
2729 elif de .model_type == 'dipole' :
28- test_dipole (args )
30+ err , siz = test_dipole (args )
2931 elif de .model_type == 'polar' :
30- test_polar (args )
32+ err , siz = test_polar (args )
3133 elif de .model_type == 'wfc' :
32- test_wfc (args )
34+ err , siz = test_wfc (args )
35+ else :
36+ raise RuntimeError ('unknow model type ' + de .model_type )
37+ print ("# ----------------------------------------------- " )
38+ err_coll .append (err )
39+ siz_coll .append (siz )
40+ avg_err = weighted_average (err_coll , siz_coll )
41+ if len (all_sys ) > 1 :
42+ print ("# ----------weighted average of errors----------- " )
43+ print ("# number of systems : %d" % len (all_sys ))
44+ if de .model_type == 'ener' :
45+ print_ener_sys_avg (avg_err )
46+ elif de .model_type == 'dipole' :
47+ print_dipole_sys_avg (avg_err )
48+ elif de .model_type == 'polar' :
49+ print_polar_sys_avg (avg_err )
50+ elif de .model_type == 'wfc' :
51+ print_wfc_sys_avg (avg_err )
3352 else :
3453 raise RuntimeError ('unknow model type ' + de .model_type )
3554 print ("# ----------------------------------------------- " )
@@ -39,6 +58,23 @@ def l2err (diff) :
3958 return np .sqrt (np .average (diff * diff ))
4059
4160
61+ def weighted_average (err_coll , siz_coll ):
62+ nsys = len (err_coll )
63+ nitems = len (err_coll [0 ])
64+ assert (len (err_coll ) == len (siz_coll ))
65+ sum_err = np .zeros (nitems )
66+ sum_siz = np .zeros (nitems )
67+ for sys_error , sys_size in zip (err_coll , siz_coll ):
68+ for ii in range (nitems ):
69+ ee = sys_error [ii ]
70+ ss = sys_size [ii ]
71+ sum_err [ii ] += ee * ee * ss
72+ sum_siz [ii ] += ss
73+ for ii in range (nitems ):
74+ sum_err [ii ] = np .sqrt (sum_err [ii ] / sum_siz [ii ])
75+ return sum_err
76+
77+
4278def test_ener (args ) :
4379 if args .rand_seed is not None :
4480 np .random .seed (args .rand_seed % (2 ** 32 ))
@@ -121,6 +157,13 @@ def test_ener (args) :
121157 axis = 1 )
122158 np .savetxt (detail_file + ".v.out" , pv ,
123159 header = 'data_vxx data_vxy data_vxz data_vyx data_vyy data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz' )
160+ return [l2ea , l2f , l2va ], [energy .size , force .size , virial .size ]
161+
162+
163+ def print_ener_sys_avg (avg ):
164+ print ("Energy L2err/Natoms : %e eV" % avg [0 ])
165+ print ("Force L2err : %e eV/A" % avg [1 ])
166+ print ("Virial L2err/Natoms : %e eV" % avg [2 ])
124167
125168
126169def test_wfc (args ) :
@@ -154,6 +197,11 @@ def test_wfc (args) :
154197 axis = 1 )
155198 np .savetxt (detail_file + ".out" , pe ,
156199 header = 'ref_wfc(12 dofs) predicted_wfc(12 dofs)' )
200+ return [l2f ], [wfc .size ]
201+
202+
203+ def print_wfc_sys_avg (avg ):
204+ print ("WFC L2err : %e eV/A" % avg [0 ])
157205
158206
159207def test_polar (args ) :
@@ -187,6 +235,11 @@ def test_polar (args) :
187235 axis = 1 )
188236 np .savetxt (detail_file + ".out" , pe ,
189237 header = 'data_pxx data_pxy data_pxz data_pyx data_pyy data_pyz data_pzx data_pzy data_pzz pred_pxx pred_pxy pred_pxz pred_pyx pred_pyy pred_pyz pred_pzx pred_pzy pred_pzz' )
238+ return [l2f ], [polar .size ]
239+
240+
241+ def print_polar_sys_avg (avg ):
242+ print ("Polarizability L2err : %e eV/A" % avg [0 ])
190243
191244
192245def test_dipole (args ) :
@@ -220,3 +273,8 @@ def test_dipole (args) :
220273 axis = 1 )
221274 np .savetxt (detail_file + ".out" , pe ,
222275 header = 'data_x data_y data_z pred_x pred_y pred_z' )
276+ return [l2f ], [dipole .size ]
277+
278+
279+ def print_dipole_sys_avg (avg ):
280+ print ("Dipole L2err : %e eV/A" % avg [0 ])
0 commit comments