Skip to content

Commit ae557ed

Browse files
committed
add torchsim integration and tests
1 parent a5634fe commit ae557ed

File tree

4 files changed

+125
-7
lines changed

4 files changed

+125
-7
lines changed

README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,31 @@ calculator = NequIPPolCalculator.from_compiled_model(
6060
)
6161
```
6262

63+
### Compile for TorchSim
64+
65+
Use the `allegro-pol` batched target to generate a compiled model for
66+
`allegro_pol.integrations.torchsim.NequIPPolTorchSimCalc`:
67+
68+
```bash
69+
nequip-compile \
70+
path/to/model.ckpt \
71+
path/to/compiled_model.nequip.pt2 \
72+
--device cuda \
73+
--mode aotinductor \
74+
--target batch_pol_bc
75+
```
76+
77+
Then load it with the `allegro-pol` TorchSim calculator:
78+
79+
```python
80+
from allegro_pol.integrations.torchsim import NequIPPolTorchSimCalc
81+
82+
calculator = NequIPPolTorchSimCalc.from_compiled_model(
83+
compile_path="path/to/compiled_model.nequip.pt2",
84+
device="cuda", # or "cpu"
85+
)
86+
```
87+
6388
### Compile for LAMMPS pair styles
6489

6590
Use an `allegro-pol` LAMMPS target with `nequip-compile` for pair-style integrations:

allegro_pol/_compile.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
from nequip.data import AtomicDataDict
33
from nequip.scripts._compile_utils import (
44
ASE_OUTPUTS,
5+
BATCH_INPUTS,
56
LMP_OUTPUTS,
67
PAIR_NEQUIP_INPUTS,
8+
batched_data_settings,
79
single_frame_data_settings,
810
single_frame_batch_map_settings,
911
register_compile_targets,
@@ -16,6 +18,7 @@
1618
AOTI_PAIR_ALLEGRO_POL_TARGET = "pair_allegro_pol"
1719
AOTI_PAIR_ALLEGRO_POL_BC_TARGET = "pair_allegro_pol_bc"
1820
AOTI_ASE_POL_BC_TARGET = "ase_pol_bc"
21+
AOTI_BATCH_POL_BC_TARGET = "batch_pol_bc"
1922

2023

2124
PAIR_ALLEGRO_POL_OUTPUTS = [*LMP_OUTPUTS, AtomicDataDict.POLARIZATION_KEY]
@@ -53,10 +56,18 @@
5356
"data_settings": single_frame_data_settings,
5457
}
5558

59+
BATCH_POL_BC_TARGET = {
60+
"input": BATCH_INPUTS,
61+
"output": ASE_POL_BC_OUTPUTS,
62+
"batch_map_settings": lambda batch_map: batch_map, # no static shapes
63+
"data_settings": batched_data_settings,
64+
}
65+
5666
register_compile_targets(
5767
{
5868
AOTI_PAIR_ALLEGRO_POL_TARGET: PAIR_ALLEGRO_POL_TARGET,
5969
AOTI_PAIR_ALLEGRO_POL_BC_TARGET: PAIR_ALLEGRO_POL_BC_TARGET,
6070
AOTI_ASE_POL_BC_TARGET: ASE_POL_BC_TARGET,
71+
AOTI_BATCH_POL_BC_TARGET: BATCH_POL_BC_TARGET,
6172
}
6273
)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# This file is a part of the `allegro-pol` package. Please see LICENSE and README at the root for information on using it.
2+
from typing import Dict
3+
4+
import torch
5+
6+
from nequip.data import AtomicDataDict
7+
from nequip.integrations.torchsim import NequIPTorchSimCalc
8+
from nequip.scripts._compile_utils import COMPILE_TARGET_DICT
9+
10+
from .._compile import AOTI_BATCH_POL_BC_TARGET
11+
from .._keys import POLARIZABILITY_KEY
12+
13+
14+
class NequIPPolTorchSimCalc(NequIPTorchSimCalc):
15+
@classmethod
16+
def _get_aoti_compile_target(cls) -> Dict:
17+
return COMPILE_TARGET_DICT[AOTI_BATCH_POL_BC_TARGET]
18+
19+
def save_extra_outputs(
20+
self, out: dict[str, torch.Tensor], results: dict[str, torch.Tensor]
21+
) -> None:
22+
if AtomicDataDict.POLARIZATION_KEY in out:
23+
results[AtomicDataDict.POLARIZATION_KEY] = out[
24+
AtomicDataDict.POLARIZATION_KEY
25+
].detach()
26+
27+
if AtomicDataDict.BORN_CHARGE_KEY in out:
28+
results[AtomicDataDict.BORN_CHARGE_KEY] = out[
29+
AtomicDataDict.BORN_CHARGE_KEY
30+
].detach()
31+
32+
if POLARIZABILITY_KEY in out:
33+
results[POLARIZABILITY_KEY] = out[POLARIZABILITY_KEY].detach()

tests/model/test_allegro_pol.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
# This file is a part of the `allegro-pol` package. Please see LICENSE and README at the root for information on using it.
22
import pytest
3-
from nequip.utils.unittests.model_tests_compilation import CompilationTestsMixin
3+
from nequip.data import AtomicDataDict
4+
from nequip.utils.unittests.model_tests_ase_integration import ASEIntegrationMixin
5+
from nequip.utils.unittests.model_tests_train_time_compile import TrainTimeCompileMixin
6+
from nequip.utils.unittests.model_tests_torchsim import TorchSimIntegrationMixin
47
from nequip.utils.versions import _TORCH_GE_2_6
58

9+
from allegro_pol._keys import POLARIZABILITY_KEY
10+
611
_CUEQ_INSTALLED = False
712

813
if _TORCH_GE_2_6:
@@ -51,15 +56,17 @@
5156
)
5257

5358

54-
class TestAllegroPol(CompilationTestsMixin):
59+
class TestAllegroPol(
60+
TrainTimeCompileMixin, ASEIntegrationMixin, TorchSimIntegrationMixin
61+
):
5562
"""Test suite for Allegro Polarization models"""
5663

5764
@pytest.fixture
5865
def strict_locality(self):
5966
return True
6067

6168
@pytest.fixture(scope="class")
62-
def nequip_compile_tol(self, model_dtype):
69+
def ase_integration_tol(self, model_dtype):
6370
return {"float32": 5e-5, "float64": 1e-10}[model_dtype]
6471

6572
@pytest.fixture(scope="class")
@@ -69,11 +76,49 @@ def ase_calculator_cls(self):
6976
return NequIPPolCalculator
7077

7178
@pytest.fixture(scope="class")
72-
def ase_aoti_compile_target(self):
79+
def ase_aoti_target(self):
7380
from allegro_pol._compile import AOTI_ASE_POL_BC_TARGET
7481

7582
return AOTI_ASE_POL_BC_TARGET
7683

84+
@pytest.fixture(scope="class")
85+
def ase_properties_to_compare(self):
86+
return [
87+
"energy",
88+
"forces",
89+
AtomicDataDict.POLARIZATION_KEY,
90+
AtomicDataDict.BORN_CHARGE_KEY,
91+
]
92+
93+
@pytest.fixture(scope="class")
94+
def torchsim_calculator_cls(self):
95+
from allegro_pol.integrations.torchsim import NequIPPolTorchSimCalc
96+
97+
return NequIPPolTorchSimCalc
98+
99+
@pytest.fixture(scope="class")
100+
def torchsim_reference_ase_calculator_cls(self):
101+
from allegro_pol.integrations.ase import NequIPPolCalculator
102+
103+
return NequIPPolCalculator
104+
105+
@pytest.fixture(scope="class")
106+
def torchsim_aoti_target(self):
107+
from allegro_pol._compile import AOTI_BATCH_POL_BC_TARGET
108+
109+
return AOTI_BATCH_POL_BC_TARGET
110+
111+
@pytest.fixture(scope="class")
112+
def torchsim_properties_to_compare(self):
113+
return [
114+
"energy",
115+
"forces",
116+
"stress",
117+
AtomicDataDict.POLARIZATION_KEY,
118+
AtomicDataDict.BORN_CHARGE_KEY,
119+
POLARIZABILITY_KEY,
120+
]
121+
77122
@pytest.fixture(
78123
scope="class",
79124
params=[None]
@@ -83,7 +128,7 @@ def ase_aoti_compile_target(self):
83128
else []
84129
),
85130
)
86-
def nequip_compile_acceleration_modifiers(self, request):
131+
def ase_compile_modifiers(self, request):
87132
if request.param is None:
88133
return None
89134

@@ -103,6 +148,10 @@ def modifier_handler(mode, device, model_dtype):
103148

104149
return modifier_handler
105150

151+
@pytest.fixture(scope="class")
152+
def torchsim_compile_modifiers(self, ase_compile_modifiers):
153+
return ase_compile_modifiers
154+
106155
@pytest.fixture(
107156
scope="class",
108157
params=[None]
@@ -112,11 +161,11 @@ def modifier_handler(mode, device, model_dtype):
112161
else []
113162
),
114163
)
115-
def train_time_compile_acceleration_modifiers(self, request):
164+
def train_time_compile_modifiers(self, request):
116165
if request.param is None:
117166
return None
118167

119-
def modifier_handler(device):
168+
def modifier_handler(device, model_dtype):
120169
if request.param == "enable_CuEquivarianceContracter":
121170
if device == "cpu":
122171
pytest.skip("CuEquivarianceContracter tests skipped for CPU")

0 commit comments

Comments
 (0)