|
1 | 1 | import operator
|
2 | 2 | import sys
|
3 | 3 | import warnings
|
| 4 | +from collections.abc import Callable |
4 | 5 | from functools import singledispatch
|
5 | 6 |
|
6 | 7 | import numba
|
|
18 | 19 | from pytensor.compile.ops import DeepCopyOp
|
19 | 20 | from pytensor.graph.fg import FunctionGraph
|
20 | 21 | from pytensor.ifelse import IfElse
|
| 22 | +from pytensor.link.numba.cache import ( |
| 23 | + cache_node_key, |
| 24 | +) |
21 | 25 | from pytensor.link.numba.compile import (
|
| 26 | + compile_and_cache_numba_function_src, |
22 | 27 | get_numba_type,
|
23 | 28 | numba_njit,
|
24 | 29 | )
|
@@ -208,20 +213,80 @@ def perform(*inputs):
|
208 | 213 | ret = py_perform_return(inputs)
|
209 | 214 | return ret
|
210 | 215 |
|
211 |
| - return perform |
| 216 | + # Assume we can't cache python functions |
| 217 | + return perform, None |
212 | 218 |
|
213 | 219 |
|
@singledispatch
def numba_funcify(
    op, node=None, storage_map=None, **kwargs
) -> Callable | tuple[Callable, str | int | None]:
    """Build a numba-jitable function implementing `op` at `node`.

    The returned function is usually compiled with numba's
    `no_cpython_wrapper` option, so it cannot be invoked directly from
    Python — only from other jitted functions.

    Implementations registered on this dispatcher may also return a
    ``(function, key)`` pair: a ``str``/``int`` key supplies extra caching
    context, while ``None`` disables caching for that function. A bare
    callable means the op/node signature alone is a sufficient cache key.
    """
    # No specialized implementation registered: fall back to the generic
    # object-mode wrapper around the op's Python `perform`.
    fallback = generate_fallback_impl(op, node, storage_map, **kwargs)
    return fallback
|
223 | 236 |
|
224 | 237 |
|
| 238 | +def numba_funcify_njit(op, node, **kwargs): |
| 239 | + jitable_func_and_key = numba_funcify(op, node=node, **kwargs) |
| 240 | + |
| 241 | + match jitable_func_and_key: |
| 242 | + case Callable(): |
| 243 | + jitable_func = jitable_func_and_key |
| 244 | + key = cache_node_key(node) |
| 245 | + case (Callable(), str() | int()): |
| 246 | + jitable_func, funcify_key = jitable_func_and_key |
| 247 | + key = cache_node_key(node, funcify_key) |
| 248 | + case (Callable(), None): |
| 249 | + # We were explicitly told by the dispatch not to try and cache this function |
| 250 | + jitable_func, key = jitable_func_and_key |
| 251 | + case _: |
| 252 | + raise TypeError( |
| 253 | + f"numpy_funcify should return a callable or a (callable, key) pair, got {jitable_func_and_key}" |
| 254 | + ) |
| 255 | + |
| 256 | + if key is not None: |
| 257 | + # To force numba to use our cache, we must compile the function so that any closure |
| 258 | + # becomes a global variable... |
| 259 | + op_name = op.__class__.__name__ |
| 260 | + cached_func = compile_and_cache_numba_function_src( |
| 261 | + src=f"def {op_name}(*args): return jitable_func(*args)", |
| 262 | + function_name=op_name, |
| 263 | + global_env=globals() | dict(jitable_func=jitable_func), |
| 264 | + cache_key=key, |
| 265 | + ) |
| 266 | + return numba_njit(cached_func, final_function=True, cache=True) |
| 267 | + else: |
| 268 | + return numba_njit( |
| 269 | + lambda *args: jitable_func(*args), final_function=True, cache=False |
| 270 | + ) |
| 271 | + |
| 272 | + |
@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="numba_funcified_fgraph",
    **kwargs,
):
    """Lower a whole `FunctionGraph` to a Python function whose node ops are njit-compiled."""
    # TODO: Create hash key for whole graph
    conversion_kwargs = dict(
        op_conversion_fn=numba_funcify_njit,
        type_conversion_fn=numba_typify,
        fgraph_name=fgraph_name,
    )
    return fgraph_to_python(fgraph, **conversion_kwargs, **kwargs)
| 288 | + |
| 289 | + |
225 | 290 | @numba_funcify.register(OpFromGraph)
|
226 | 291 | def numba_funcify_OpFromGraph(op, node=None, **kwargs):
|
227 | 292 | _ = kwargs.pop("storage_map", None)
|
@@ -251,23 +316,8 @@ def opfromgraph(*inputs):
|
251 | 316 | def opfromgraph(*inputs):
|
252 | 317 | return fgraph_fn(*inputs)
|
253 | 318 |
|
254 |
| - return opfromgraph |
255 |
| - |
256 |
| - |
257 |
| -@numba_funcify.register(FunctionGraph) |
258 |
| -def numba_funcify_FunctionGraph( |
259 |
| - fgraph, |
260 |
| - node=None, |
261 |
| - fgraph_name="numba_funcified_fgraph", |
262 |
| - **kwargs, |
263 |
| -): |
264 |
| - return fgraph_to_python( |
265 |
| - fgraph, |
266 |
| - numba_funcify, |
267 |
| - type_conversion_fn=numba_typify, |
268 |
| - fgraph_name=fgraph_name, |
269 |
| - **kwargs, |
270 |
| - ) |
| 319 | + # We can't cache this correctly until we can define a key for it |
| 320 | + return opfromgraph, None |
271 | 321 |
|
272 | 322 |
|
273 | 323 | @numba_funcify.register(DeepCopyOp)
|
|
0 commit comments