1
+ import warnings
1
2
import weakref
3
+ from collections .abc import Callable
4
+ from functools import singledispatch , wraps
2
5
from hashlib import sha256
3
6
from pathlib import Path
7
+ from pickle import dumps
8
+ from tempfile import NamedTemporaryFile
9
+ from typing import Any
4
10
5
11
from numba .core .caching import CacheImpl , _CacheLocator
6
12
7
13
from pytensor import config
8
14
from pytensor .graph .basic import Apply
15
+ from pytensor .link .numba .compile import numba_funcify , numba_njit
9
16
10
17
11
18
NUMBA_PYTENSOR_CACHE_ENABLED = True
@@ -19,8 +26,6 @@ def __init__(self, py_func, py_file, hash):
19
26
self ._py_func = py_func
20
27
self ._py_file = py_file
21
28
self ._hash = hash
22
- # src_hash = hash(pytensor_loader._module_sources[self._py_file])
23
- # self._hash = hash((src_hash, py_file, pytensor.__version__))
24
29
25
30
def ensure_cache_path (self ):
26
31
pass
@@ -74,3 +79,165 @@ def cache_node_key(node: Apply, extra_key="") -> str:
74
79
),
75
80
).encode ()
76
81
).hexdigest ()
82
+
83
+
84
+ @singledispatch
85
+ def numba_funcify_default_op_cache_key (
86
+ op , node = None , ** kwargs
87
+ ) -> Callable | tuple [Callable , Any ]:
88
+ """Funcify an Op and implement a default cache key.
89
+
90
+ The default cache key is based on the op class and its properties.
91
+ It does not take into account the node inputs or other context.
92
+ Note that numba will use the array dtypes, rank and layout as part of the cache key,
93
+ but not the static shape or constant values.
94
+ If the funcify implementation exploits this information, then this method should not be used.
95
+ Instead dispatch directly on `numba_funcify_and_cache_key` (or just numba_funcify)
96
+ which won't use any cache key.
97
+ """
98
+ # Default cache key of None which means "don't try to do directly cache this function"
99
+ raise NotImplementedError ()
100
+
101
+
102
+ def register_funcify_default_op_cache_key (op_type ):
103
+ """Register a funcify implementation for both cache and non-cache versions."""
104
+
105
+ def decorator (dispatch_func ):
106
+ # Register with the cache key dispatcher
107
+ numba_funcify_default_op_cache_key .register (op_type )(dispatch_func )
108
+
109
+ # Create a wrapper for the non-cache dispatcher
110
+ @wraps (dispatch_func )
111
+ def dispatch_func_wrapper (* args , ** kwargs ):
112
+ func , key = dispatch_func (* args , ** kwargs )
113
+ # Discard the key for the non-cache version
114
+ return func
115
+
116
+ # Register the wrapper with the non-cache dispatcher
117
+ numba_funcify .register (op_type )(dispatch_func_wrapper )
118
+
119
+ return dispatch_func
120
+
121
+ return decorator
122
+
123
+
124
+ @singledispatch
125
+ def numba_funcify_and_cache_key (op , node = None , ** kwargs ) -> tuple [Callable , str | None ]:
126
+ # Default cache key of None which means "don't try to do directly cache this function"
127
+ if hasattr (op , "_props" ):
128
+ try :
129
+ func_and_salt = numba_funcify_default_op_cache_key (op , node = node , ** kwargs )
130
+ except NotImplementedError :
131
+ pass
132
+ else :
133
+ if isinstance (func_and_salt , tuple ):
134
+ func , salt = func_and_salt
135
+ else :
136
+ func , salt = func_and_salt , "0"
137
+ props_dict = op ._props_dict ()
138
+ if not props_dict :
139
+ # Simple op, just use the type string as key
140
+ key_bytes = str ((type (op ), salt )).encode ()
141
+ else :
142
+ # Simple props, can use string representation of props as key
143
+ simple_types = (str , bool , int , type (None ), float )
144
+ container_types = (tuple , frozenset )
145
+ if all (
146
+ isinstance (v , simple_types )
147
+ or (
148
+ isinstance (v , container_types )
149
+ and all (isinstance (i , simple_types ) for i in v )
150
+ )
151
+ for v in props_dict .values ()
152
+ ):
153
+ key_bytes = str (
154
+ (type (op ), tuple (props_dict .items ()), salt )
155
+ ).encode ()
156
+ else :
157
+ # Complex props, use pickle to serialize them
158
+ key_bytes = dumps ((str (type (op )), tuple (props_dict .items ()), salt ))
159
+ return func , sha256 (key_bytes ).hexdigest ()
160
+
161
+ # Fallback
162
+ return numba_funcify (op , node = node , ** kwargs ), None
163
+
164
+
165
+ def register_funcify_and_cache_key (op_type ):
166
+ """Register a funcify implementation for both cache and non-cache versions."""
167
+
168
+ def decorator (dispatch_func ):
169
+ # Register with the cache key dispatcher
170
+ numba_funcify_and_cache_key .register (op_type )(dispatch_func )
171
+
172
+ # Create a wrapper for the non-cache dispatcher
173
+ @wraps (dispatch_func )
174
+ def dispatch_func_wrapper (* args , ** kwargs ):
175
+ func , key = dispatch_func (* args , ** kwargs )
176
+ # Discard the key for the non-cache version
177
+ return func
178
+
179
+ # Register the wrapper with the non-cache dispatcher
180
+ numba_funcify .register (op_type )(dispatch_func_wrapper )
181
+
182
+ return dispatch_func_wrapper
183
+
184
+ return decorator
185
+
186
+
187
+ def numba_njit_and_cache (op , * args , ** kwargs ):
188
+ jitable_func , key = numba_funcify_and_cache_key (op , * args , ** kwargs )
189
+
190
+ if key is not None :
191
+ # To force numba to use our cache, we must compile the function so that any closure
192
+ # becomes a global variable...
193
+ op_name = op .__class__ .__name__
194
+ cached_func = compile_numba_function_src (
195
+ src = f"def { op_name } (*args): return jitable_func(*args)" ,
196
+ function_name = op_name ,
197
+ global_env = globals () | {"jitable_func" : jitable_func },
198
+ cache_key = key ,
199
+ )
200
+ return numba_njit (cached_func , final_function = True , cache = True ), key
201
+ else :
202
+ if config .numba__cache and config .compiler_verbose :
203
+ warnings .warn (
204
+ f"Custom numba cache disabled for { op } of type { type (op )} . "
205
+ f"Even if the function is cached by numba, larger graphs using this function cannot be cached.\n "
206
+ "To enable custom caching, register a numba_funcify_and_cache_key implementation for this Op, with a proper cache key."
207
+ )
208
+
209
+ return numba_njit (
210
+ lambda * args : jitable_func (* args ), final_function = True , cache = False
211
+ ), None
212
+
213
+
214
+ def compile_numba_function_src (
215
+ src : str ,
216
+ function_name : str ,
217
+ global_env : dict [Any , Any ] | None = None ,
218
+ local_env : dict [Any , Any ] | None = None ,
219
+ store_to_disk : bool = False ,
220
+ cache_key : str | None = None ,
221
+ ) -> Callable :
222
+ if store_to_disk :
223
+ with NamedTemporaryFile (delete = False ) as f :
224
+ filename = f .name
225
+ f .write (src .encode ())
226
+ else :
227
+ filename = "<string>"
228
+
229
+ if global_env is None :
230
+ global_env = {}
231
+
232
+ if local_env is None :
233
+ local_env = {}
234
+
235
+ mod_code = compile (src , filename , mode = "exec" )
236
+ exec (mod_code , global_env , local_env )
237
+
238
+ res = local_env [function_name ]
239
+ res .__source__ = src # type: ignore
240
+
241
+ if cache_key is not None :
242
+ CACHED_SRC_FUNCTIONS [res ] = cache_key
243
+ return res
0 commit comments