Skip to content

Commit e2d98d9

Browse files
committed
Make sure functions with side-effects are included in TorchScript code
1 parent 3129b92 commit e2d98d9

File tree

3 files changed

+75
-10
lines changed

3 files changed

+75
-10
lines changed

metatomic-torch/src/register.cpp

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,51 @@ TORCH_LIBRARY(metatomic, m) {
208208
m.def("version() -> str", metatomic_torch::version);
209209

210210
m.def("read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata", read_model_metadata);
211-
m.def("check_atomistic_model(str path) -> ()", check_atomistic_model);
212-
m.def("load_model_extensions(str path, str? extensions_directory) -> ()", load_model_extensions);
213-
214211
m.def("unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float", unit_conversion_factor);
215-
m.def(
216-
"register_autograd_neighbors("
217-
"__torch__.torch.classes.metatomic.System system, "
218-
"__torch__.torch.classes.metatensor.TensorBlock neighbors, "
219-
"bool check_consistency = False"
220-
") -> ()",
221-
register_autograd_neighbors
212+
213+
// manually construct the schema for "check_atomistic_model(str path) -> ()",
214+
// so we can set AliasAnalysisKind to CONSERVATIVE. In turn, this make it so
215+
// the TorchScript compiler knows this function has side-effects, and does
216+
// not remove it from the graph.
217+
auto schema = c10::FunctionSchema(
218+
/*name=*/"check_atomistic_model",
219+
/*overload_name=*/"check_atomistic_model",
220+
/*arguments=*/{
221+
c10::Argument("path", c10::getTypePtr<std::string>()),
222+
},
223+
/*returns=*/{}
224+
);
225+
schema.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE);
226+
m.def(std::move(schema), check_atomistic_model);
227+
228+
// "load_model_extensions(str path, str? extensions_directory) -> ()"
229+
schema = c10::FunctionSchema(
230+
/*name=*/"load_model_extensions",
231+
/*overload_name=*/"load_model_extensions",
232+
/*arguments=*/{
233+
c10::Argument("path", c10::getTypePtr<std::string>()),
234+
c10::Argument("extensions_directory", c10::getTypePtr<c10::optional<std::string>>()),
235+
},
236+
/*returns=*/{}
237+
);
238+
schema.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE);
239+
m.def(std::move(schema), load_model_extensions);
240+
241+
// "register_autograd_neighbors("
242+
// "__torch__.torch.classes.metatomic.System system, "
243+
// "__torch__.torch.classes.metatensor.TensorBlock neighbors, "
244+
// "bool check_consistency = False"
245+
// ") -> ()",
246+
schema = c10::FunctionSchema(
247+
/*name=*/"register_autograd_neighbors",
248+
/*overload_name=*/"register_autograd_neighbors",
249+
/*arguments=*/{
250+
c10::Argument("system", c10::getTypePtr<metatomic_torch::System>()),
251+
c10::Argument("neighbors", c10::getTypePtr<metatensor_torch::TensorBlock>()),
252+
c10::Argument("check_consistency", c10::getTypePtr<bool>(), c10::nullopt, /*default_value=*/false),
253+
},
254+
/*returns=*/{}
222255
);
256+
schema.setAliasAnalysis(c10::AliasAnalysisKind::CONSERVATIVE);
257+
m.def(std::move(schema), register_autograd_neighbors);
223258
}

python/metatomic_torch/tests/model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
check_atomistic_model,
1818
is_atomistic_model,
1919
load_atomistic_model,
20+
load_model_extensions,
2021
read_model_metadata,
2122
)
2223

@@ -126,6 +127,23 @@ def test_recreate(model, tmp_path):
126127
check_atomistic_model("export_new.pt")
127128

128129

130+
def test_torch_script():
131+
# make sure functions that have side effects are properly included in the
132+
# TorchScript code
133+
134+
@torch.jit.script
135+
def test_function(path: str):
136+
check_atomistic_model(path)
137+
138+
assert "ops.metatomic.check_atomistic_model" in test_function.code
139+
140+
@torch.jit.script
141+
def test_function(path: str, extensions_directory: Optional[str]):
142+
load_model_extensions(path, extensions_directory)
143+
144+
assert "ops.metatomic.load_model_extensions" in test_function.code
145+
146+
129147
def test_training_mode():
130148
model = MinimalModel()
131149
model.train(True)

python/metatomic_torch/tests/neighbors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
import torch
3+
from metatensor.torch import TensorBlock
34

45
from metatomic.torch import (
56
NeighborListOptions,
@@ -178,3 +179,14 @@ def test_neighbor_autograd_errors():
178179
)
179180
with pytest.raises(ValueError, match=message):
180181
register_autograd_neighbors(system, neighbors, check_consistency=True)
182+
183+
184+
def test_torch_script():
185+
# make sure functions that have side effects are properly included in the
186+
# TorchScript code
187+
188+
@torch.jit.script
189+
def test_function(system: System, neighbors: TensorBlock, check_consistency: bool):
190+
register_autograd_neighbors(system, neighbors, check_consistency)
191+
192+
assert "ops.metatomic.register_autograd_neighbors" in test_function.code

0 commit comments

Comments
 (0)