|
6 | 6 | import logging |
7 | 7 | import warnings |
8 | 8 | from collections.abc import Sequence |
| 9 | +from pathlib import Path |
9 | 10 | from typing import Any |
10 | 11 |
|
11 | 12 | import amici |
@@ -342,3 +343,57 @@ def _Model__simulate( |
342 | 343 | solver=_get_ptr(solver), |
343 | 344 | edata=_get_ptr(edata), |
344 | 345 | ) |
| 346 | + |
| 347 | + |
| 348 | +def restore_model( |
| 349 | + module_name: str, module_path: Path, settings: dict, checksum: str = None |
| 350 | +) -> amici.Model: |
| 351 | + """ |
| 352 | + Recreate a model instance with given settings. |
| 353 | +
|
| 354 | + For use in ModelPtr.__reduce__. |
| 355 | +
|
| 356 | + :param module_name: |
| 357 | + Name of the model module. |
| 358 | + :param module_path: |
| 359 | + Path to the model module. |
| 360 | + :param settings: |
| 361 | + Model settings to be applied. |
| 362 | + See `set_model_settings` / `get_model_settings`. |
| 363 | + :param checksum: |
| 364 | + Checksum of the model extension to verify integrity. |
| 365 | + """ |
| 366 | + from . import import_model_module |
| 367 | + |
| 368 | + model_module = import_model_module(module_name, module_path) |
| 369 | + model = model_module.get_model() |
| 370 | + model.module = model_module._self |
| 371 | + set_model_settings(model, settings) |
| 372 | + |
| 373 | + if checksum is not None and checksum != file_checksum( |
| 374 | + model.module.extension_path |
| 375 | + ): |
| 376 | + raise RuntimeError( |
| 377 | + f"Model file checksum does not match the expected checksum " |
| 378 | + f"({checksum}). The model file may have been modified " |
| 379 | + f"after the model was pickled." |
| 380 | + ) |
| 381 | + |
| 382 | + return model |
| 383 | + |
| 384 | + |
| 385 | +def file_checksum( |
| 386 | + path: str | Path, algorithm: str = "sha256", chunk_size: int = 8192 |
| 387 | +) -> str: |
| 388 | + """ |
| 389 | + Compute checksum for `path` using `algorithm` (e.g. 'md5', 'sha1', 'sha256'). |
| 390 | + Returns the hexadecimal digest string. |
| 391 | + """ |
| 392 | + import hashlib |
| 393 | + |
| 394 | + path = Path(path) |
| 395 | + h = hashlib.new(algorithm) |
| 396 | + with path.open("rb") as f: |
| 397 | + for chunk in iter(lambda: f.read(chunk_size), b""): |
| 398 | + h.update(chunk) |
| 399 | + return h.hexdigest() |
0 commit comments