@@ -52,52 +52,16 @@ def compute_mae_stress(delta: jnp.ndarray, mask) -> float:
52
52
return _masked_mean_stress (jnp .abs (delta ), mask )
53
53
54
54
55
- def compute_rel_mae (delta : jnp .ndarray , target_val : jnp .ndarray , mask ) -> float :
56
- target_norm = _masked_mean (jnp .abs (target_val ), mask )
57
- return _masked_mean (jnp .abs (delta ), mask ) / (target_norm + 1e-30 )
55
+ def compute_mse (delta : jnp .ndarray , mask ) -> float :
56
+ return _masked_mean (jnp .square (delta ), mask )
58
57
59
58
60
- def compute_rel_mae_f (delta : jnp .ndarray , target_val : jnp .ndarray , mask ) -> float :
61
- target_norm = _masked_mean_f (jnp .abs (target_val ), mask )
62
- return _masked_mean_f (jnp .abs (delta ), mask ) / (target_norm + 1e-30 )
59
+ def compute_mse_f (delta : jnp .ndarray , mask ) -> float :
60
+ return _masked_mean_f (jnp .square (delta ), mask )
63
61
64
62
65
- def compute_rel_mae_stress (delta : jnp .ndarray , target_val : jnp .ndarray , mask ) -> float :
66
- target_norm = _masked_mean_stress (jnp .abs (target_val ), mask )
67
- return _masked_mean_stress (jnp .abs (delta ), mask ) / (target_norm + 1e-30 )
68
-
69
-
70
- def compute_rmse (delta : jnp .ndarray , mask ) -> float :
71
- return jnp .sqrt (_masked_mean (jnp .square (delta ), mask ))
72
-
73
-
74
- def compute_rmse_f (delta : jnp .ndarray , mask ) -> float :
75
- return jnp .sqrt (_masked_mean_f (jnp .square (delta ), mask ))
76
-
77
-
78
- def compute_rmse_stress (delta : jnp .ndarray , mask ) -> float :
79
- return jnp .sqrt (_masked_mean_stress (jnp .square (delta ), mask ))
80
-
81
-
82
- def compute_rel_rmse (delta : jnp .ndarray , target_val : jnp .ndarray , mask ) -> float :
83
- target_norm = jnp .sqrt (_masked_mean (jnp .square (target_val ), mask ))
84
- return jnp .sqrt (_masked_mean (jnp .square (delta ), mask )) / (target_norm + 1e-30 )
85
-
86
-
87
- def compute_rel_rmse_f (delta : jnp .ndarray , target_val : jnp .ndarray , mask ) -> float :
88
- target_norm = jnp .sqrt (_masked_mean_f (jnp .square (target_val ), mask ))
89
- return jnp .sqrt (_masked_mean_f (jnp .square (delta ), mask )) / (target_norm + 1e-30 )
90
-
91
-
92
- def compute_rel_rmse_stress (delta : jnp .ndarray , target_val : jnp .ndarray , mask ) -> float :
93
- target_norm = jnp .sqrt (_masked_mean_stress (jnp .square (target_val ), mask ))
94
- return jnp .sqrt (_masked_mean_stress (jnp .square (delta ), mask )) / (
95
- target_norm + 1e-30
96
- )
97
-
98
-
99
- def compute_q95 (delta : jnp .ndarray ) -> float :
100
- return jnp .percentile (jnp .abs (delta ), q = 95 )
63
+ def compute_mse_stress (delta : jnp .ndarray , mask ) -> float :
64
+ return _masked_mean_stress (jnp .square (delta ), mask )
101
65
102
66
103
67
def _sum_nodes_of_the_same_graph (
@@ -295,120 +259,66 @@ def compute_eval_metrics(
295
259
stress_per_atom_list .append (ref_graph .globals .stress / jnp .sum (node_mask ))
296
260
297
261
metrics = {
298
- "mae_e" : None ,
299
- "rel_mae_e" : None ,
300
- "mae_e_per_atom" : None ,
301
- "rel_mae_e_per_atom" : None ,
302
- "rmse_e" : None ,
303
- "rel_rmse_e" : None ,
304
- "rmse_e_per_atom" : None ,
305
- "rel_rmse_e_per_atom" : None ,
306
- "q95_e" : None ,
307
- "mae_f" : None ,
308
- "rel_mae_f" : None ,
309
- "rmse_f" : None ,
310
- "rel_rmse_f" : None ,
311
- "q95_f" : None ,
312
- "mae_stress" : None ,
313
- "rel_mae_stress" : None ,
314
- "mae_stress_per_atom" : None ,
315
- "rel_mae_stress_per_atom" : None ,
316
- "rmse_stress" : None ,
317
- "rel_rmse_stress" : None ,
318
- "rmse_stress_per_atom" : None ,
319
- "rel_rmse_stress_per_atom" : None ,
320
- "q95_stress" : None ,
262
+ "mae_e" : jnp .nan ,
263
+ "mae_e_per_atom" : jnp .nan ,
264
+ "mse_e" : jnp .nan ,
265
+ "mse_e_per_atom" : jnp .nan ,
266
+ "mae_f" : jnp .nan ,
267
+ "mse_f" : jnp .nan ,
268
+ "mae_stress" : jnp .nan ,
269
+ "mae_stress_per_atom" : jnp .nan ,
270
+ "mse_stress" : jnp .nan ,
271
+ "mse_stress_per_atom" : jnp .nan ,
321
272
}
322
273
323
274
if len (delta_es_list ) > 0 :
324
275
delta_es = jnp .concatenate (delta_es_list , axis = 0 )
325
276
delta_es_per_atom = jnp .concatenate (delta_es_per_atom_list , axis = 0 )
326
- es = jnp .concatenate (es_list , axis = 0 )
327
- es_per_atom = jnp .concatenate (es_per_atom_list , axis = 0 )
328
277
329
278
metrics .update (
330
279
{
331
280
# Mean absolute error
332
281
"mae_e" : compute_mae (delta_es , graph_mask ),
333
- # Root-mean -square error
334
- "rmse_e " : compute_rmse (delta_es , graph_mask ),
282
+ # Mean -square error
283
+ "mse_e " : compute_mse (delta_es , graph_mask ),
335
284
}
336
285
)
337
286
if extended_metrics :
338
287
metrics .update (
339
288
{
340
289
# Mean absolute error
341
- "rel_mae_e" : compute_rel_mae (delta_es , es , graph_mask ),
342
290
"mae_e_per_atom" : compute_mae (delta_es_per_atom , graph_mask ),
343
- "rel_mae_e_per_atom" : compute_rel_mae (
344
- delta_es_per_atom , es_per_atom , graph_mask
345
- ),
346
- # Root-mean-square error
347
- "rel_rmse_e" : compute_rel_rmse (delta_es , es , graph_mask ),
348
- "rmse_e_per_atom" : compute_rmse (delta_es_per_atom , graph_mask ),
349
- "rel_rmse_e_per_atom" : compute_rel_rmse (
350
- delta_es_per_atom , es_per_atom , graph_mask
351
- ),
352
- # Q_95
353
- "q95_e" : compute_q95 (delta_es ),
291
+ # Mean-square error
292
+ "mse_e_per_atom" : compute_mse (delta_es_per_atom , graph_mask ),
354
293
}
355
294
)
356
295
357
296
if len (delta_fs_list ) > 0 :
358
297
delta_fs = jnp .concatenate (delta_fs_list , axis = 0 )
359
- fs = jnp .concatenate (fs_list , axis = 0 )
360
-
361
298
metrics .update (
362
299
{
363
300
# Mean absolute error
364
301
"mae_f" : compute_mae_f (delta_fs , node_mask ),
365
- # Root-mean -square error
366
- "rmse_f " : compute_rmse_f (delta_fs , node_mask ),
302
+ # Mean -square error
303
+ "mse_f " : compute_mse_f (delta_fs , node_mask ),
367
304
}
368
305
)
369
- if extended_metrics :
370
- metrics .update (
371
- {
372
- # Mean absolute error
373
- "rel_mae_f" : compute_rel_mae_f (delta_fs , fs , node_mask ),
374
- # Root-mean-square error
375
- "rel_rmse_f" : compute_rel_rmse_f (delta_fs , fs , node_mask ),
376
- # Q_95
377
- "q95_f" : compute_q95 (delta_fs ),
378
- }
379
- )
380
306
381
307
if len (delta_stress_list ) > 0 and extended_metrics :
382
308
delta_stress = jnp .concatenate (delta_stress_list , axis = 0 )
383
309
delta_stress_per_atom = jnp .concatenate (delta_stress_per_atom_list , axis = 0 )
384
- stress = jnp .concatenate (stress_list , axis = 0 )
385
- stress_per_atom = jnp .concatenate (stress_per_atom_list , axis = 0 )
386
310
metrics .update (
387
311
{
388
312
# Mean absolute error
389
313
"mae_stress" : compute_mae_stress (delta_stress , graph_mask ),
390
- "rel_mae_stress" : compute_rel_mae_stress (
391
- delta_stress , stress , graph_mask
392
- ),
393
314
"mae_stress_per_atom" : compute_mae_stress (
394
315
delta_stress_per_atom , graph_mask
395
316
),
396
- "rel_mae_stress_per_atom" : compute_rel_mae_stress (
397
- delta_stress_per_atom , stress_per_atom , graph_mask
398
- ),
399
- # Root-mean-square error
400
- "rmse_stress" : compute_rmse_stress (delta_stress , graph_mask ),
401
- "rel_rmse_stress" : compute_rel_rmse_stress (
402
- delta_stress , stress , graph_mask
403
- ),
404
- "rmse_stress_per_atom" : compute_rmse_stress (
317
+ # Mean-square error
318
+ "mse_stress" : compute_mse_stress (delta_stress , graph_mask ),
319
+ "mse_stress_per_atom" : compute_mse_stress (
405
320
delta_stress_per_atom , graph_mask
406
321
),
407
- "rel_rmse_stress_per_atom" : compute_rel_rmse_stress (
408
- delta_stress_per_atom , stress_per_atom , graph_mask
409
- ),
410
- # Q_95
411
- "q95_stress" : compute_q95 (delta_stress ),
412
322
}
413
323
)
414
324
0 commit comments