@@ -208,16 +208,51 @@ TORCH_LIBRARY(metatomic, m) {
208208 m.def (" version() -> str" , metatomic_torch::version);
209209
210210 m.def (" read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata" , read_model_metadata);
211- m.def (" check_atomistic_model(str path) -> ()" , check_atomistic_model);
212- m.def (" load_model_extensions(str path, str? extensions_directory) -> ()" , load_model_extensions);
213-
214211 m.def (" unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float" , unit_conversion_factor);
215- m.def (
216- " register_autograd_neighbors("
217- " __torch__.torch.classes.metatomic.System system, "
218- " __torch__.torch.classes.metatensor.TensorBlock neighbors, "
219- " bool check_consistency = False"
220- " ) -> ()" ,
221- register_autograd_neighbors
212+
213+ // manually construct the schema for "check_atomistic_model(str path) -> ()",
214+ // so we can set AliasAnalysisKind to CONSERVATIVE. In turn, this make it so
215+ // the TorchScript compiler knows this function has side-effects, and does
216+ // not remove it from the graph.
217+ auto schema = c10::FunctionSchema (
218+ /* name=*/ " check_atomistic_model" ,
219+ /* overload_name=*/ " check_atomistic_model" ,
220+ /* arguments=*/ {
221+ c10::Argument (" path" , c10::getTypePtr<std::string>()),
222+ },
223+ /* returns=*/ {}
224+ );
225+ schema.setAliasAnalysis (c10::AliasAnalysisKind::CONSERVATIVE);
226+ m.def (std::move (schema), check_atomistic_model);
227+
228+ // "load_model_extensions(str path, str? extensions_directory) -> ()"
229+ schema = c10::FunctionSchema (
230+ /* name=*/ " load_model_extensions" ,
231+ /* overload_name=*/ " load_model_extensions" ,
232+ /* arguments=*/ {
233+ c10::Argument (" path" , c10::getTypePtr<std::string>()),
234+ c10::Argument (" extensions_directory" , c10::getTypePtr<c10::optional<std::string>>()),
235+ },
236+ /* returns=*/ {}
237+ );
238+ schema.setAliasAnalysis (c10::AliasAnalysisKind::CONSERVATIVE);
239+ m.def (std::move (schema), load_model_extensions);
240+
241+ // "register_autograd_neighbors("
242+ // "__torch__.torch.classes.metatomic.System system, "
243+ // "__torch__.torch.classes.metatensor.TensorBlock neighbors, "
244+ // "bool check_consistency = False"
245+ // ") -> ()",
246+ schema = c10::FunctionSchema (
247+ /* name=*/ " register_autograd_neighbors" ,
248+ /* overload_name=*/ " register_autograd_neighbors" ,
249+ /* arguments=*/ {
250+ c10::Argument (" system" , c10::getTypePtr<metatomic_torch::System>()),
251+ c10::Argument (" neighbors" , c10::getTypePtr<metatensor_torch::TensorBlock>()),
252+ c10::Argument (" check_consistency" , c10::getTypePtr<bool >(), c10::nullopt , /* default_value=*/ false ),
253+ },
254+ /* returns=*/ {}
222255 );
256+ schema.setAliasAnalysis (c10::AliasAnalysisKind::CONSERVATIVE);
257+ m.def (std::move (schema), register_autograd_neighbors);
223258}
0 commit comments