22
33from __future__ import annotations
44
5- import logging
6- import pickle
7- from importlib .metadata import version
8- from pathlib import Path
95from typing import TYPE_CHECKING , overload
106
117import cloudpickle
12- from ampform .sympy ._cache import (
13- get_readable_hash , # noqa: PLC2701
14- get_system_cache_directory , # noqa: PLC2701
15- )
8+ from ampform .sympy ._cache import cache_to_disk # noqa: PLC2701
169from ampform .sympy .cached import (
1710 doit , # noqa: F401 # pyright: ignore[reportUnusedImport]
1811 unfold , # noqa: F401 # pyright: ignore[reportUnusedImport]
2215
2316if TYPE_CHECKING :
2417 from collections .abc import Mapping
25- from typing import Any
2618
2719 import sympy as sp
2820 from tensorwaves .function import (
3123 )
3224 from tensorwaves .interface import Function , ParameterValue , ParametrizedFunction
3325
34- _LOGGER = logging .getLogger (__name__ )
35-
3626
3727@overload
38- def lambdify (
39- expr : sp .Expr ,
40- backend : str = "jax" ,
41- directory : str | None = None ,
42- ) -> PositionalArgumentFunction : ...
28+ def lambdify (expr : sp .Expr , * , backend : str = "jax" ) -> PositionalArgumentFunction : ...
4329@overload
4430def lambdify (
4531 expr : sp .Expr ,
4632 parameters : Mapping [sp .Symbol , ParameterValue ],
33+ * ,
4734 backend : str = "jax" ,
48- directory : str | None = None ,
4935) -> ParametrizedBackendFunction : ...
50- def lambdify ( # type:ignore[misc] # pyright:ignore[reportInconsistentOverload]
36+ @cache_to_disk (
37+ dump_function = cloudpickle .dump ,
38+ dependencies = ["cloudpickle" , "ampform" , "jax" , "sympy" ],
39+ )
40+ def lambdify (
5141 expr : sp .Expr ,
5242 parameters : Mapping [sp .Symbol , ParameterValue ] | None = None ,
43+ * ,
5344 backend : str = "jax" ,
54- cache_directory : Path | str | None = None ,
5545) -> ParametrizedFunction | Function :
5646 """Lambdify a SymPy `~sympy.core.expr.Expr` and cache the result to disk.
5747
@@ -65,41 +55,11 @@ def lambdify( # type:ignore[misc] # pyright:ignore[reportInconsistentOverload]
6555 parameters: Specify this argument in order to create a
6656 `~tensorwaves.function.ParametrizedBackendFunction` instead of a
6757 `~tensorwaves.function.PositionalArgumentFunction`.
68- backend: The choice of backend for the created numerical function. **WARNING**:
69- this function has only been tested for :code:`backend="jax"`!
70- directory: The directory in which to cache the result. If `None`, the cache
71- directory will be put under the home directory, or to the path specified by
72- the environment variable :code:`SYMPY_CACHE_DIR`.
58+ backend: The choice of backend for the created numerical function.
59+ **WARNING**: this function has only been tested for :code:`backend="jax"`!
7360
7461 .. seealso:: :func:`ampform.sympy.perform_cached_doit`
7562 """
76- if cache_directory is None :
77- system_cache_dir = get_system_cache_directory ()
78- backend_version = version (backend )
79- cache_directory = (
80- Path (system_cache_dir ) / "ampform_dpd" / f"{ backend } -v{ backend_version } "
81- )
82- if not isinstance (cache_directory , Path ):
83- cache_directory = Path (cache_directory )
84- cache_directory .mkdir (exist_ok = True , parents = True )
85- if parameters is None :
86- hash_obj : Any = expr
87- else :
88- hash_obj = (
89- expr ,
90- tuple ((s , parameters [s ]) for s in sorted (parameters , key = str )),
91- )
92- h = get_readable_hash (hash_obj )
93- filename = cache_directory / f"{ h } .pkl"
94- if filename .exists ():
95- with open (filename , "rb" ) as f :
96- return pickle .load (f )
97- _LOGGER .warning (f"Cached function file { filename } not found, lambdifying..." )
98- func : ParametrizedFunction | Function
9963 if parameters is None :
100- func = create_function (expr , backend )
101- else :
102- func = create_parametrized_function (expr , parameters , backend )
103- with open (filename , "wb" ) as f :
104- cloudpickle .dump (func , f )
105- return func
64+ return create_function (expr , backend )
65+ return create_parametrized_function (expr , parameters , backend )
0 commit comments