Skip to content

Commit b8307a1

Browse files
authored
Enable pickling of ModelPtr (#2985)
Enable pickling of `ModelPtr` for use in multiprocessing contexts. Related to #1126.
1 parent 785ab4b commit b8307a1

File tree

5 files changed

+108
-1
lines changed

5 files changed

+108
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ See also our [versioning policy](https://amici.readthedocs.io/en/latest/versioni
6161
This is a wrapper for both `amici.run_simulation` and
6262
`amici.run_simulations`, depending on the type of the `edata` argument.
6363
It also supports passing some `Solver` options as keyword arguments.
64+
* `amici.ModelPtr` now supports sufficient pickling for use in
65+
multi-processing contexts. This works only if the amici-generated model
66+
package exists in the same file system location and does not change until
67+
unpickling.
6468

6569
## v0.X Series
6670

python/sdist/amici/swig_wrappers.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
import warnings
88
from collections.abc import Sequence
9+
from pathlib import Path
910
from typing import Any
1011

1112
import amici
@@ -342,3 +343,57 @@ def _Model__simulate(
342343
solver=_get_ptr(solver),
343344
edata=_get_ptr(edata),
344345
)
346+
347+
348+
def restore_model(
349+
module_name: str, module_path: Path, settings: dict, checksum: str = None
350+
) -> amici.Model:
351+
"""
352+
Recreate a model instance with given settings.
353+
354+
For use in ModelPtr.__reduce__.
355+
356+
:param module_name:
357+
Name of the model module.
358+
:param module_path:
359+
Path to the model module.
360+
:param settings:
361+
Model settings to be applied.
362+
See `set_model_settings` / `get_model_settings`.
363+
:param checksum:
364+
Checksum of the model extension to verify integrity.
365+
"""
366+
from . import import_model_module
367+
368+
model_module = import_model_module(module_name, module_path)
369+
model = model_module.get_model()
370+
model.module = model_module._self
371+
set_model_settings(model, settings)
372+
373+
if checksum is not None and checksum != file_checksum(
374+
model.module.extension_path
375+
):
376+
raise RuntimeError(
377+
f"Model file checksum does not match the expected checksum "
378+
f"({checksum}). The model file may have been modified "
379+
f"after the model was pickled."
380+
)
381+
382+
return model
383+
384+
385+
def file_checksum(
386+
path: str | Path, algorithm: str = "sha256", chunk_size: int = 8192
387+
) -> str:
388+
"""
389+
Compute checksum for `path` using `algorithm` (e.g. 'md5', 'sha1', 'sha256').
390+
Returns the hexadecimal digest string.
391+
"""
392+
import hashlib
393+
394+
path = Path(path)
395+
h = hashlib.new(algorithm)
396+
with path.open("rb") as f:
397+
for chunk in iter(lambda: f.read(chunk_size), b""):
398+
h.update(chunk)
399+
return h.hexdigest()

python/tests/test_swig_interface.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55

66
import copy
77
import numbers
8+
import pickle
89
from math import nan
910

1011
import amici
1112
import numpy as np
1213
import pytest
1314
import xarray
15+
from amici import SteadyStateSensitivityMode
16+
from amici.testing import skip_on_valgrind
1417

1518

1619
def test_version_number(pysb_example_presimulation_module):
@@ -664,3 +667,33 @@ def test_reporting_mode_obs_llh(sbml_example_presimulation_module):
664667
assert rdata.ssigmay is None
665668
assert rdata.sllh.size > 0
666669
assert not np.isnan(rdata.sllh).any()
670+
671+
672+
@skip_on_valgrind
673+
def test_pickle_model(sbml_example_presimulation_module):
674+
model_module = sbml_example_presimulation_module
675+
model = model_module.get_model()
676+
677+
assert (
678+
model.get_steady_state_sensitivity_mode()
679+
== SteadyStateSensitivityMode.integrationOnly
680+
)
681+
model.set_steady_state_sensitivity_mode(
682+
SteadyStateSensitivityMode.newtonOnly
683+
)
684+
685+
model_pickled = pickle.loads(pickle.dumps(model))
686+
# ensure it's re-picklable
687+
model_pickled = pickle.loads(pickle.dumps(model_pickled))
688+
assert (
689+
model_pickled.get_steady_state_sensitivity_mode()
690+
== SteadyStateSensitivityMode.newtonOnly
691+
)
692+
693+
model_pickled.set_steady_state_sensitivity_mode(
694+
SteadyStateSensitivityMode.integrateIfNewtonFails
695+
)
696+
assert (
697+
model.get_steady_state_sensitivity_mode()
698+
!= model_pickled.get_steady_state_sensitivity_mode()
699+
)

swig/model.i

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,20 @@ def simulate(
195195
def __deepcopy__(self, memo):
196196
return self.clone()
197197

198+
def __reduce__(self):
199+
from amici.swig_wrappers import restore_model, get_model_settings, file_checksum
200+
201+
return (
202+
restore_model,
203+
(
204+
self.get_name(),
205+
Path(self.module.__spec__.origin).parent,
206+
get_model_settings(self),
207+
file_checksum(self.module.extension_path),
208+
),
209+
{}
210+
)
211+
198212

199213
@overload
200214
def simulate(

swig/modelname.template.i

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@ import sysconfig
88
from pathlib import Path
99
1010
ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
11+
extension_path = Path(__file__).parent / f'_TPL_MODELNAME{ext_suffix}'
1112
_TPL_MODELNAME = amici._module_from_path(
1213
'TPL_MODELNAME._TPL_MODELNAME' if __package__ or '.' in __name__
1314
else '_TPL_MODELNAME',
14-
Path(__file__).parent / f'_TPL_MODELNAME{ext_suffix}',
15+
extension_path,
1516
)
1617
1718
def _get_import_time():

0 commit comments

Comments
 (0)