2121)
2222
2323
24- def calc_model_devi_f (fs : np .ndarray ) -> Tuple [np .ndarray ]:
24+ def calc_model_devi_f (
25+ fs : np .ndarray , real_f : Optional [np .ndarray ] = None
26+ ) -> Tuple [np .ndarray ]:
2527 """Calculate model deviation of force.
2628
2729 Parameters
2830 ----------
2931 fs : numpy.ndarray
3032 size of `n_models x n_frames x n_atoms x 3`
33+ real_f : numpy.ndarray or None
34+ real force, size of `n_frames x n_atoms x 3`. If given,
35+ the RMS real error is calculated instead.
3136
3237 Returns
3338 -------
@@ -38,14 +43,21 @@ def calc_model_devi_f(fs: np.ndarray) -> Tuple[np.ndarray]:
3843 avg_devi_f : numpy.ndarray
3944 average deviation of force in all atoms
4045 """
41- fs_devi = np .linalg .norm (np .std (fs , axis = 0 ), axis = - 1 )
46+ if real_f is None :
47+ fs_devi = np .linalg .norm (np .std (fs , axis = 0 ), axis = - 1 )
48+ else :
49+ fs_devi = np .linalg .norm (
50+ np .sqrt (np .mean (np .square (fs - real_f ), axis = 0 )), axis = - 1
51+ )
4252 max_devi_f = np .max (fs_devi , axis = - 1 )
4353 min_devi_f = np .min (fs_devi , axis = - 1 )
4454 avg_devi_f = np .mean (fs_devi , axis = - 1 )
4555 return max_devi_f , min_devi_f , avg_devi_f
4656
4757
48- def calc_model_devi_e (es : np .ndarray ) -> np .ndarray :
58+ def calc_model_devi_e (
59+ es : np .ndarray , real_e : Optional [np .ndarray ] = None
60+ ) -> np .ndarray :
4961 """Calculate model deviation of total energy per atom.
5062
5163 Here we don't use the atomic energy, as the decomposition
@@ -56,24 +68,35 @@ def calc_model_devi_e(es: np.ndarray) -> np.ndarray:
5668 ----------
5769 es : numpy.ndarray
5870 size of `n_models x n_frames x 1
71+ real_e : numpy.ndarray
72+ real energy, size of `n_frames x 1`. If given,
73+ the RMS real error is calculated instead.
5974
6075 Returns
6176 -------
6277 max_devi_e : numpy.ndarray
6378 maximum deviation of energy
6479 """
65- es_devi = np .std (es , axis = 0 )
80+ if real_e is None :
81+ es_devi = np .std (es , axis = 0 )
82+ else :
83+ es_devi = np .sqrt (np .mean (np .square (es - real_e ), axis = 0 ))
6684 es_devi = np .squeeze (es_devi , axis = - 1 )
6785 return es_devi
6886
6987
70- def calc_model_devi_v (vs : np .ndarray ) -> Tuple [np .ndarray ]:
88+ def calc_model_devi_v (
89+ vs : np .ndarray , real_v : Optional [np .ndarray ] = None
90+ ) -> Tuple [np .ndarray ]:
7191 """Calculate model deviation of virial.
7292
7393 Parameters
7494 ----------
7595 vs : numpy.ndarray
7696 size of `n_models x n_frames x 9`
97+ real_v : numpy.ndarray
98+ real virial, size of `n_frames x 9`. If given,
99+ the RMS real error is calculated instead.
77100
78101 Returns
79102 -------
@@ -84,7 +107,10 @@ def calc_model_devi_v(vs: np.ndarray) -> Tuple[np.ndarray]:
84107 avg_devi_v : numpy.ndarray
85108 average deviation of virial in 9 elements
86109 """
87- vs_devi = np .std (vs , axis = 0 )
110+ if real_v is None :
111+ vs_devi = np .std (vs , axis = 0 )
112+ else :
113+ vs_devi = np .sqrt (np .mean (np .square (vs - real_v ), axis = 0 ))
88114 max_devi_v = np .max (vs_devi , axis = - 1 )
89115 min_devi_v = np .min (vs_devi , axis = - 1 )
90116 avg_devi_v = np .linalg .norm (vs_devi , axis = - 1 ) / 3
@@ -148,6 +174,7 @@ def calc_model_devi(
148174 mixed_type = False ,
149175 fparam : Optional [np .ndarray ] = None ,
150176 aparam : Optional [np .ndarray ] = None ,
177+ real_data : Optional [dict ] = None ,
151178):
152179 """Python interface to calculate model deviation.
153180
@@ -171,6 +198,8 @@ def calc_model_devi(
171198 frame specific parameters
172199 aparam : numpy.ndarray
173200 atomic specific parameters
201+ real_data : dict, optional
202+ real data to calculate RMS real error
174203
175204 Returns
176205 -------
@@ -211,17 +240,29 @@ def calc_model_devi(
211240 virials = np .array (virials )
212241
213242 devi = [np .arange (coord .shape [0 ]) * frequency ]
214- devi += list (calc_model_devi_v (virials ))
215- devi += list (calc_model_devi_f (forces ))
216- devi .append (calc_model_devi_e (energies ))
243+ if real_data is None :
244+ devi += list (calc_model_devi_v (virials ))
245+ devi += list (calc_model_devi_f (forces ))
246+ devi .append (calc_model_devi_e (energies ))
247+ else :
248+ devi += list (calc_model_devi_v (virials , real_data ["virial" ]))
249+ devi += list (calc_model_devi_f (forces , real_data ["force" ]))
250+ devi .append (calc_model_devi_e (energies , real_data ["energy" ]))
217251 devi = np .vstack (devi ).T
218252 if fname :
219253 write_model_devi_out (devi , fname )
220254 return devi
221255
222256
223257def make_model_devi (
224- * , models : list , system : str , set_prefix : str , output : str , frequency : int , ** kwargs
258+ * ,
259+ models : list ,
260+ system : str ,
261+ set_prefix : str ,
262+ output : str ,
263+ frequency : int ,
264+ real_error : bool = False ,
265+ ** kwargs ,
225266):
226267 """Make model deviation calculation.
227268
@@ -239,6 +280,8 @@ def make_model_devi(
239280 The number of steps that elapse between writing coordinates
240281 in a trajectory by a MD engine (such as Gromacs / Lammps).
241282 This paramter is used to determine the index in the output file.
283+ real_error : bool, default: False
284+ If True, calculate the RMS real error instead of model deviation.
242285 **kwargs
243286 Arbitrary keyword arguments.
244287 """
@@ -279,6 +322,29 @@ def make_model_devi(
279322 must = True ,
280323 high_prec = False ,
281324 )
325+ if real_error :
326+ dp_data .add (
327+ "energy" ,
328+ 1 ,
329+ atomic = False ,
330+ must = False ,
331+ high_prec = True ,
332+ )
333+ dp_data .add (
334+ "force" ,
335+ 3 ,
336+ atomic = True ,
337+ must = False ,
338+ high_prec = False ,
339+ )
340+ dp_data .add (
341+ "virial" ,
342+ 9 ,
343+ atomic = False ,
344+ must = False ,
345+ high_prec = False ,
346+ )
347+
282348 mixed_type = dp_data .mixed_type
283349
284350 data_sets = [dp_data ._load_set (set_name ) for set_name in dp_data .dirs ]
@@ -301,6 +367,15 @@ def make_model_devi(
301367 aparam = data ["aparam" ]
302368 else :
303369 aparam = None
370+ if real_error :
371+ natoms = atype .shape [- 1 ]
372+ real_data = {
373+ "energy" : data ["energy" ] / natoms ,
374+ "force" : data ["force" ].reshape ([- 1 , natoms , 3 ]),
375+ "virial" : data ["virial" ] / natoms ,
376+ }
377+ else :
378+ real_data = None
304379 devi = calc_model_devi (
305380 coord ,
306381 box ,
@@ -309,6 +384,7 @@ def make_model_devi(
309384 mixed_type = mixed_type ,
310385 fparam = fparam ,
311386 aparam = aparam ,
387+ real_data = real_data ,
312388 )
313389 nframes_tot += coord .shape [0 ]
314390 devis .append (devi )
0 commit comments