13
13
from pymatgen .analysis .structure_matcher import StructureMatcher
14
14
from pymatgen .core import Structure
15
15
16
+ from atomate2 .forcefields import MLFF
16
17
from atomate2 .forcefields .md import (
17
18
CHGNetMDMaker ,
18
19
ForceFieldMDMaker ,
24
25
)
25
26
26
27
name_to_maker = {
27
- " CHGNet" : CHGNetMDMaker ,
28
- " M3GNet" : M3GNetMDMaker ,
29
- " MACE" : MACEMDMaker ,
30
- " GAP" : GAPMDMaker ,
31
- " NEP" : NEPMDMaker ,
32
- " Nequip" : NequipMDMaker ,
28
+ MLFF . CHGNet : CHGNetMDMaker ,
29
+ MLFF . M3GNet : M3GNetMDMaker ,
30
+ MLFF . MACE : MACEMDMaker ,
31
+ MLFF . GAP : GAPMDMaker ,
32
+ MLFF . NEP : NEPMDMaker ,
33
+ MLFF . Nequip : NequipMDMaker ,
33
34
}
34
35
35
36
@@ -47,40 +48,42 @@ def test_maker_initialization():
47
48
)
48
49
49
50
50
- @pytest .mark .parametrize (
51
- "ff_name" ,
52
- ["CHGNet" , "M3GNet" , "MACE" , "GAP" , "NEP" , "Nequip" ],
53
- )
51
+ @pytest .mark .parametrize ("ff_name" , MLFF )
54
52
def test_ml_ff_md_maker (
55
53
ff_name , si_structure , sr_ti_o3_structure , al2_au_structure , test_dir , clean_dir
56
54
):
57
- if ff_name == "GAP" and sys .version_info >= (3 , 12 ):
55
+ if ff_name == MLFF .Forcefield :
56
+ return # nothing to test here, MLFF.Forcefield is just a generic placeholder
57
+ if ff_name == MLFF .GAP and sys .version_info >= (3 , 12 ):
58
58
pytest .skip (
59
59
"GAP model not compatible with Python 3.12, waiting on https://github.com/libAtoms/QUIP/issues/645"
60
60
)
61
+ if ff_name == MLFF .M3GNet :
62
+ pytest .skip ("M3GNet requires DGL which is PyTorch 2.4 incompatible" )
61
63
62
64
n_steps = 5
63
65
64
66
ref_energies_per_atom = {
65
- "CHGNet" : - 5.280157089233398 ,
66
- "M3GNet" : - 5.387282371520996 ,
67
- "MACE" : - 5.311369895935059 ,
68
- "GAP" : - 5.391255755606209 ,
69
- "NEP" : - 3.966232215741286 ,
70
- "Nequip" : - 8.84670181274414 ,
67
+ MLFF .CHGNet : - 5.280157089233398 ,
68
+ MLFF .M3GNet : - 5.387282371520996 ,
69
+ MLFF .MACE : - 5.311369895935059 ,
70
+ MLFF .GAP : - 5.391255755606209 ,
71
+ MLFF .NEP : - 3.966232215741286 ,
72
+ MLFF .Nequip : - 8.84670181274414 ,
73
+ MLFF .SevenNet : - 5.394115447998047 ,
71
74
}
72
75
73
76
# ASE can slightly change tolerances on structure positions
74
77
matcher = StructureMatcher ()
75
78
76
79
calculator_kwargs = {}
77
80
unit_cell_structure = si_structure .copy ()
78
- if ff_name == " GAP" :
81
+ if ff_name == MLFF . GAP :
79
82
calculator_kwargs = {
80
83
"args_str" : "IP GAP" ,
81
84
"param_filename" : str (test_dir / "forcefields" / "gap" / "gap_file.xml" ),
82
85
}
83
- elif ff_name == " NEP" :
86
+ elif ff_name == MLFF . NEP :
84
87
# NOTE: The test NEP model is specifically trained on 16 elemental metals
85
88
# thus a new Al2Au structure is added.
86
89
# The NEP model used for the tests is licensed under a
@@ -91,7 +94,7 @@ def test_ml_ff_md_maker(
91
94
"model_filename" : test_dir / "forcefields" / "nep" / "nep.txt"
92
95
}
93
96
unit_cell_structure = al2_au_structure .copy ()
94
- elif ff_name == " Nequip" :
97
+ elif ff_name == MLFF . Nequip :
95
98
calculator_kwargs = {
96
99
"model_path" : test_dir / "forcefields" / "nequip" / "nequip_ff_sr_ti_o3.pth"
97
100
}
@@ -137,9 +140,9 @@ def test_ml_ff_md_maker(
137
140
for key in ("energy" , "forces" , "stress" , "velocities" , "temperature" )
138
141
for step in task_doc .objects ["trajectory" ].frame_properties
139
142
)
140
-
141
- with pytest .warns (FutureWarning ):
142
- name_to_maker [ ff_name ] ()
143
+ if ff_maker := name_to_maker . get ( ff_name ):
144
+ with pytest .warns (FutureWarning ):
145
+ ff_maker ()
143
146
144
147
145
148
@pytest .mark .parametrize ("traj_file" , ["trajectory.json.gz" , "atoms.traj" ])
0 commit comments