Skip to content

Commit b86e0d6

Browse files
authored
ENH: improve hash behavior with cache_to_disk (#157)
* BREAK: make `backend` argument keyword only * ENH: embed `cloudpickle` and `ampform` as `lambdify` dependencies * ENH: embed jax and sympy version in cache dependency * MAINT: improve docstring positioning
1 parent aa43055 commit b86e0d6

File tree

1 file changed

+13
-53
lines changed

1 file changed

+13
-53
lines changed

src/ampform_dpd/io/cached.py

Lines changed: 13 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,10 @@
22

33
from __future__ import annotations
44

5-
import logging
6-
import pickle
7-
from importlib.metadata import version
8-
from pathlib import Path
95
from typing import TYPE_CHECKING, overload
106

117
import cloudpickle
12-
from ampform.sympy._cache import (
13-
get_readable_hash, # noqa: PLC2701
14-
get_system_cache_directory, # noqa: PLC2701
15-
)
8+
from ampform.sympy._cache import cache_to_disk # noqa: PLC2701
169
from ampform.sympy.cached import (
1710
doit, # noqa: F401 # pyright: ignore[reportUnusedImport]
1811
unfold, # noqa: F401 # pyright: ignore[reportUnusedImport]
@@ -22,7 +15,6 @@
2215

2316
if TYPE_CHECKING:
2417
from collections.abc import Mapping
25-
from typing import Any
2618

2719
import sympy as sp
2820
from tensorwaves.function import (
@@ -31,27 +23,25 @@
3123
)
3224
from tensorwaves.interface import Function, ParameterValue, ParametrizedFunction
3325

34-
_LOGGER = logging.getLogger(__name__)
35-
3626

3727
@overload
38-
def lambdify(
39-
expr: sp.Expr,
40-
backend: str = "jax",
41-
directory: str | None = None,
42-
) -> PositionalArgumentFunction: ...
28+
def lambdify(expr: sp.Expr, *, backend: str = "jax") -> PositionalArgumentFunction: ...
4329
@overload
4430
def lambdify(
4531
expr: sp.Expr,
4632
parameters: Mapping[sp.Symbol, ParameterValue],
33+
*,
4734
backend: str = "jax",
48-
directory: str | None = None,
4935
) -> ParametrizedBackendFunction: ...
50-
def lambdify( # type:ignore[misc] # pyright:ignore[reportInconsistentOverload]
36+
@cache_to_disk(
37+
dump_function=cloudpickle.dump,
38+
dependencies=["cloudpickle", "ampform", "jax", "sympy"],
39+
)
40+
def lambdify(
5141
expr: sp.Expr,
5242
parameters: Mapping[sp.Symbol, ParameterValue] | None = None,
43+
*,
5344
backend: str = "jax",
54-
cache_directory: Path | str | None = None,
5545
) -> ParametrizedFunction | Function:
5646
"""Lambdify a SymPy `~sympy.core.expr.Expr` and cache the result to disk.
5747
@@ -65,41 +55,11 @@ def lambdify( # type:ignore[misc] # pyright:ignore[reportInconsistentOverload]
6555
parameters: Specify this argument in order to create a
6656
`~tensorwaves.function.ParametrizedBackendFunction` instead of a
6757
`~tensorwaves.function.PositionalArgumentFunction`.
68-
backend: The choice of backend for the created numerical function. **WARNING**:
69-
this function has only been tested for :code:`backend="jax"`!
70-
directory: The directory in which to cache the result. If `None`, the cache
71-
directory will be put under the home directory, or to the path specified by
72-
the environment variable :code:`SYMPY_CACHE_DIR`.
58+
backend: The choice of backend for the created numerical function.
59+
**WARNING**: this function has only been tested for :code:`backend="jax"`!
7360
7461
.. seealso:: :func:`ampform.sympy.perform_cached_doit`
7562
"""
76-
if cache_directory is None:
77-
system_cache_dir = get_system_cache_directory()
78-
backend_version = version(backend)
79-
cache_directory = (
80-
Path(system_cache_dir) / "ampform_dpd" / f"{backend}-v{backend_version}"
81-
)
82-
if not isinstance(cache_directory, Path):
83-
cache_directory = Path(cache_directory)
84-
cache_directory.mkdir(exist_ok=True, parents=True)
85-
if parameters is None:
86-
hash_obj: Any = expr
87-
else:
88-
hash_obj = (
89-
expr,
90-
tuple((s, parameters[s]) for s in sorted(parameters, key=str)),
91-
)
92-
h = get_readable_hash(hash_obj)
93-
filename = cache_directory / f"{h}.pkl"
94-
if filename.exists():
95-
with open(filename, "rb") as f:
96-
return pickle.load(f)
97-
_LOGGER.warning(f"Cached function file {filename} not found, lambdifying...")
98-
func: ParametrizedFunction | Function
9963
if parameters is None:
100-
func = create_function(expr, backend)
101-
else:
102-
func = create_parametrized_function(expr, parameters, backend)
103-
with open(filename, "wb") as f:
104-
cloudpickle.dump(func, f)
105-
return func
64+
return create_function(expr, backend)
65+
return create_parametrized_function(expr, parameters, backend)

0 commit comments

Comments
 (0)