Commit 87ecb5f

Allow lazy import of linker dispatchers
1 parent 60811aa commit 87ecb5f

File tree: 9 files changed, +59 -46 lines
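The commit applies the same pattern to all three backends: the `*_typify` / `*_funcify` `singledispatch` entry points move out of each backend's `dispatch/basic.py` (which imports jax/numba/torch at module level) into the corresponding `linker.py`, and any heavy backend import happens inside the default implementation. Below is a minimal sketch of that pattern; the names (`backend_typify`, `backend_funcify`, `ExampleOp`) are illustrative stand-ins, not PyTensor's actual API.

```python
# Lightweight "linker" module: defines the dispatchers without importing the backend.
from functools import singledispatch


@singledispatch
def backend_typify(data, dtype=None, **kwargs):
    """Default conversion; the backend is imported only when this actually runs."""
    import jax.numpy as jnp  # deferred, so importing this module never needs JAX

    if dtype is None:
        return data
    return jnp.array(data, dtype=dtype)


@singledispatch
def backend_funcify(obj, *args, **kwargs):
    """Base dispatcher; backend-specific overloads register themselves later."""
    raise NotImplementedError(f"No conversion for type: {type(obj)}")


# A specialization module (the analogue of dispatch/basic.py) registers overloads
# on these entry points; in the commit this is done for `Op` and `np.ndarray`.
class ExampleOp:
    """Stand-in for `pytensor.graph.Op` in this sketch."""


@backend_funcify.register(ExampleOp)
def backend_funcify_example_op(op, node=None, storage_map=None, **kwargs):
    # Mirrors the fallback the commit registers via `@jax_funcify.register(Op)`.
    raise NotImplementedError(f"No conversion for the given op: {op}")
```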

pytensor/link/jax/dispatch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # isort: off
-from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
+from pytensor.link.jax.linker import jax_funcify, jax_typify
 
 # Load dispatch specializations
 import pytensor.link.jax.dispatch.blas

pytensor/link/jax/dispatch/basic.py

Lines changed: 4 additions & 12 deletions
@@ -1,6 +1,5 @@
 import warnings
 from collections.abc import Callable
-from functools import singledispatch
 
 import jax
 import jax.numpy as jnp
@@ -10,8 +9,10 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp, ViewOp
 from pytensor.configdefaults import config
+from pytensor.graph import Op
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
+from pytensor.link.jax.linker import jax_funcify, jax_typify
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import Assert, CheckAndRaise
 
@@ -22,24 +23,15 @@
     jax.config.update("jax_enable_x64", False)
 
 
-@singledispatch
-def jax_typify(data, dtype=None, **kwargs):
-    r"""Convert instances of PyTensor `Type`\s to JAX types."""
-    if dtype is None:
-        return data
-    else:
-        return jnp.array(data, dtype=dtype)
-
-
 @jax_typify.register(np.ndarray)
 def jax_typify_ndarray(data, dtype=None, **kwargs):
     if len(data.shape) == 0:
         return data.item()
     return jnp.array(data, dtype=dtype)
 
 
-@singledispatch
-def jax_funcify(op, node=None, storage_map=None, **kwargs):
+@jax_funcify.register(Op)
+def jax_funcify_op(op, node=None, storage_map=None, **kwargs):
     """Create a JAX compatible function from an PyTensor `Op`."""
     raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")

pytensor/link/jax/linker.py

Lines changed: 18 additions & 3 deletions
@@ -1,11 +1,29 @@
 import warnings
+from functools import singledispatch
 
 from numpy.random import Generator
 
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.link.basic import JITLinker
 
 
+@singledispatch
+def jax_typify(data, dtype=None, **kwargs):
+    r"""Convert instances of PyTensor `Type`\s to JAX types."""
+    import jax.numpy as jnp
+
+    if dtype is None:
+        return data
+    else:
+        return jnp.array(data, dtype=dtype)
+
+
+@singledispatch
+def jax_funcify(obj, *args, **kwargs):
+    """Create a JAX compatible function from an PyTensor `Op`."""
+    raise NotImplementedError(f"No JAX conversion for the given type: {type(obj)}")
+
+
 class JAXLinker(JITLinker):
     """A `Linker` that JIT-compiles NumPy-based operations using JAX."""
 
@@ -14,7 +32,6 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
-        from pytensor.link.jax.dispatch import jax_funcify
         from pytensor.link.jax.dispatch.shape import JAXShapeTuple
         from pytensor.tensor.random.type import RandomType
 
@@ -111,8 +128,6 @@ def convert_scalar_shape_inputs(
         return convert_scalar_shape_inputs
 
     def create_thunk_inputs(self, storage_map):
-        from pytensor.link.jax.dispatch import jax_typify
-
         thunk_inputs = []
         for n in self.fgraph.inputs:
            sinput = storage_map[n]
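A rough illustration of what the relocated dispatchers buy, assuming a fresh process in which nothing else has imported JAX yet (and JAX installed only if the conversion is actually called); this check is not part of the commit:

```python
import sys

import numpy as np

from pytensor.link.jax.linker import jax_typify  # lightweight: does not import JAX

print("jax" in sys.modules)                 # expected False in a fresh process
jax_typify(np.arange(3), dtype="float32")   # default impl imports jax.numpy here
print("jax" in sys.modules)                 # expected True (requires JAX installed)
```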

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # isort: off
-from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
+from pytensor.link.numba.linker import numba_funcify, numba_typify
 
 # Load dispatch specializations
 import pytensor.link.numba.dispatch.blockwise
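Because the dispatch package now re-exports the dispatchers from `linker.py`, existing import paths keep resolving to the same objects. A hypothetical quick check (assumes Numba is installed, since importing the dispatch package still loads the specializations):

```python
# Both import paths should refer to the same singledispatch object after this change.
from pytensor.link.numba.dispatch import numba_funcify as via_dispatch
from pytensor.link.numba.linker import numba_funcify as via_linker

assert via_dispatch is via_linker
```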

pytensor/link/numba/dispatch/basic.py

Lines changed: 4 additions & 14 deletions
@@ -2,7 +2,6 @@
 import sys
 import warnings
 from copy import copy
-from functools import singledispatch
 from textwrap import dedent
 
 import numba
@@ -21,11 +20,13 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph import Op
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.type import Type
 from pytensor.ifelse import IfElse
 from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
+from pytensor.link.numba.linker import numba_funcify, numba_typify
 from pytensor.link.utils import (
     compile_function_src,
     fgraph_to_python,
@@ -276,11 +277,6 @@ def create_arg_string(x):
     return args
 
 
-@singledispatch
-def numba_typify(data, dtype=None, **kwargs):
-    return data
-
-
 def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
     """Create a Numba compatible function from a Pytensor `Op`."""
 
@@ -326,14 +322,8 @@ def perform(*inputs):
         return perform
 
 
-@singledispatch
-def numba_funcify(op, node=None, storage_map=None, **kwargs):
-    """Generate a numba function for a given op and apply node.
-
-    The resulting function will usually use the `no_cpython_wrapper`
-    argument in numba, so it can not be called directly from python,
-    but only from other jit functions.
-    """
+@numba_funcify.register(Op)
+def numba_funcify_op(op, node=None, storage_map=None, **kwargs):
     return generate_fallback_impl(op, node, storage_map, **kwargs)

pytensor/link/numba/linker.py

Lines changed: 14 additions & 2 deletions
@@ -1,12 +1,24 @@
+from functools import singledispatch
+
 from pytensor.link.basic import JITLinker
 
 
+@singledispatch
+def numba_typify(data, dtype=None, **kwargs):
+    raise NotImplementedError(
+        f"Numba funcify not implemented for data type {type(data)}"
+    )
+
+
+@singledispatch
+def numba_funcify(obj, *args, **kwargs):
+    raise NotImplementedError(f"Numba funcify not implemented for type {type(obj)}")
+
+
 class NumbaLinker(JITLinker):
     """A `Linker` that JIT-compiles NumPy-based operations using Numba."""
 
     def fgraph_convert(self, fgraph, **kwargs):
-        from pytensor.link.numba.dispatch import numba_funcify
-
         return numba_funcify(fgraph, **kwargs)
 
     def jit_compile(self, fn):

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # isort: off
-from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
+from pytensor.link.pytorch.linker import pytorch_funcify, pytorch_typify
 
 # # Load dispatch specializations
 import pytensor.link.pytorch.dispatch.blas

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 4 additions & 8 deletions
@@ -1,4 +1,3 @@
-from functools import singledispatch
 from types import NoneType
 
 import numpy as np
@@ -10,9 +9,11 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph import Op
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
+from pytensor.link.pytorch.linker import pytorch_funcify, pytorch_typify
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor.basic import (
@@ -27,11 +28,6 @@
 )
 
 
-@singledispatch
-def pytorch_typify(data, **kwargs):
-    raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
-
-
 @pytorch_typify.register(np.ndarray)
 @pytorch_typify.register(torch.Tensor)
 def pytorch_typify_tensor(data, dtype=None, **kwargs):
@@ -45,8 +41,8 @@ def pytorch_typify_no_conversion_needed(data, **kwargs):
     return data
 
 
-@singledispatch
-def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
+@pytorch_funcify.register(Op)
+def pytorch_funcify_op(op, node=None, storage_map=None, **kwargs):
     """Create a PyTorch compatible function from an PyTensor `Op`."""
     raise NotImplementedError(
         f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation"

pytensor/link/pytorch/linker.py

Lines changed: 12 additions & 4 deletions
@@ -1,7 +1,19 @@
+from functools import singledispatch
+
 from pytensor.link.basic import JITLinker
 from pytensor.link.utils import unique_name_generator
 
 
+@singledispatch
+def pytorch_typify(data, **kwargs):
+    raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
+
+
+@singledispatch
+def pytorch_funcify(obj, *args, **kwargs):
+    raise NotImplementedError(f"pytorch_funcify is not implemented for {type(obj)}")
+
+
 class PytorchLinker(JITLinker):
     """A `Linker` that compiles NumPy-based operations using torch.compile."""
 
@@ -10,8 +22,6 @@ def __init__(self, *args, **kwargs):
         self.gen_functors = []
 
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
-        from pytensor.link.pytorch.dispatch import pytorch_funcify
-
         # We want to have globally unique names
         # across the entire pytensor graph, not
         # just the subgraph
@@ -40,8 +50,6 @@ def jit_compile(self, fn):
         # flag that tend to help our graphs
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
 
-        from pytensor.link.pytorch.dispatch import pytorch_typify
-
         class wrapper:
             """
             Pytorch would fail compiling our method when trying
