
Commit a1fcb77

Cleanup JAX Scalar dispatch
1 parent 2c03ecf

2 files changed: +54 additions, −39 deletions

2 files changed

+54
-39
lines changed

pytensor/link/jax/dispatch/scalar.py

Lines changed: 53 additions & 39 deletions
```diff
@@ -37,7 +37,7 @@ def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
     return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name))
 
 
-def check_if_inputs_scalars(node):
+def all_inputs_are_scalar(node):
     """Check whether all the inputs of an `Elemwise` are scalar values.
 
     `jax.lax` or `jax.numpy` functions systematically return `TracedArrays`,
```
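The renamed helper's docstring gets at why a scalar-only fast path exists at all. As a minimal sketch of that behavior (not part of the commit): `jax.numpy` functions promote even pure Python-scalar inputs to JAX arrays with a fixed dtype, while native Python operators leave them as plain numbers.

```python
import jax.numpy as jnp

# `jnp` functions return JAX arrays even when every input is a Python scalar,
# which pins an array dtype on the result (a weakly typed int32 here).
a = jnp.add(1, 2)
print(type(a), a.dtype)  # a JAX Array with dtype int32

# Python operators on native scalars stay native, so no array dtype is
# forced onto downstream computations.
b = 1 + 2
print(type(b))  # <class 'int'>
```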
```diff
@@ -62,54 +62,68 @@ def check_if_inputs_scalars(node):
 
 @jax_funcify.register(ScalarOp)
 def jax_funcify_ScalarOp(op, node, **kwargs):
+    """Return a JAX function that implements the same computation as the Scalar Op.
+
+    This dispatch is expected to return a JAX function that works on Array inputs as Elemwise does,
+    even though it's dispatched on the Scalar Op.
+    """
+
     # We dispatch some PyTensor operators to Python operators
     # whenever the inputs are all scalars.
-    are_inputs_scalars = check_if_inputs_scalars(node)
-    if are_inputs_scalars:
-        elemwise = elemwise_scalar(op)
-        if elemwise is not None:
-            return elemwise
-    func_name = op.nfunc_spec[0]
+    if all_inputs_are_scalar(node):
+        jax_func = jax_funcify_scalar_op_via_py_operators(op)
+        if jax_func is not None:
+            return jax_func
+
+    nfunc_spec = getattr(op, "nfunc_spec", None)
+    if nfunc_spec is None:
+        raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
+
+    func_name = nfunc_spec[0]
     if "." in func_name:
-        jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
-    else:
-        jnp_func = getattr(jnp, func_name)
-
-    if hasattr(op, "nfunc_variadic"):
-        # These are special cases that handle invalid arities due to the broken
-        # PyTensor `Op` type contract (e.g. binary `Op`s that also function as
-        # their own variadic counterparts--even when those counterparts already
-        # exist as independent `Op`s).
-        jax_variadic_func = getattr(jnp, op.nfunc_variadic)
-
-        def elemwise(*args):
-            if len(args) > op.nfunc_spec[1]:
-                return jax_variadic_func(
-                    jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
-                )
-            else:
-                return jnp_func(*args)
-
-        return elemwise
+        jax_func = functools.reduce(getattr, [jax] + func_name.split("."))
     else:
-        return jnp_func
+        jax_func = getattr(jnp, func_name)
+
+    if len(node.inputs) > op.nfunc_spec[1]:
+        # Some Scalar Ops accept a variable number of inputs, behaving as a variadic function,
+        # even though the base Op from `func_name` is specified as a binary Op.
+        # This happens with `Add`, which can work as a `Sum` for multiple scalars.
+        jax_variadic_func = getattr(jnp, op.nfunc_variadic, None)
+        if not jax_variadic_func:
+            raise NotImplementedError(
+                f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
+            )
+
+        def jax_func(*args):
+            return jax_variadic_func(
+                jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
+            )
+
+    return jax_func
 
 
 @functools.singledispatch
-def elemwise_scalar(op):
+def jax_funcify_scalar_op_via_py_operators(op):
+    """Specialized JAX dispatch for Elemwise operations where all inputs are Scalar arrays.
+
+    Scalar (constant) arrays in the JAX backend get lowered to native Python types (int, float),
+    which can perform better with Python operators and, more importantly, avoid upcasting to
+    array types not supported by some JAX functions.
+    """
     return None
 
 
-@elemwise_scalar.register(Add)
-def elemwise_scalar_add(op):
+@jax_funcify_scalar_op_via_py_operators.register(Add)
+def jax_funcify_scalar_Add(op):
     def elemwise(*inputs):
         return sum(inputs)
 
     return elemwise
 
 
-@elemwise_scalar.register(Mul)
-def elemwise_scalar_mul(op):
+@jax_funcify_scalar_op_via_py_operators.register(Mul)
+def jax_funcify_scalar_Mul(op):
     import operator
     from functools import reduce
 
```
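The new variadic branch rests on an identity: an n-ary `Add` over broadcastable inputs equals a `sum` reduction over the inputs stacked along a fresh leading axis (for `Add`, `op.nfunc_variadic` names a reduction such as `"sum"`, looked up on `jnp`). A standalone sketch of that identity, separate from the commit:

```python
import jax.numpy as jnp

# Inputs with different but broadcastable shapes, as a variadic `Add` may see.
x = jnp.array([1.0, 2.0])
y = jnp.array([[10.0, 20.0], [30.0, 40.0]])
z = 5.0

# `jnp.broadcast_arrays` gives all inputs a common shape, `jnp.stack` adds a
# leading axis of length 3, and summing over that axis reproduces x + y + z.
variadic = jnp.sum(jnp.stack(jnp.broadcast_arrays(x, y, z), axis=0), axis=0)
assert jnp.allclose(variadic, x + y + z)
```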

```diff
@@ -119,24 +133,24 @@ def elemwise(*inputs):
     return elemwise
 
 
-@elemwise_scalar.register(Sub)
-def elemwise_scalar_sub(op):
+@jax_funcify_scalar_op_via_py_operators.register(Sub)
+def jax_funcify_scalar_Sub(op):
     def elemwise(x, y):
         return x - y
 
     return elemwise
 
 
-@elemwise_scalar.register(IntDiv)
-def elemwise_scalar_intdiv(op):
+@jax_funcify_scalar_op_via_py_operators.register(IntDiv)
+def jax_funcify_scalar_IntDiv(op):
     def elemwise(x, y):
         return x // y
 
     return elemwise
 
 
-@elemwise_scalar.register(Mod)
-def elemwise_scalar_mod(op):
+@jax_funcify_scalar_op_via_py_operators.register(Mod)
+def jax_funcify_scalar_Mod(op):
     def elemwise(x, y):
         return x % y
 
```
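The singledispatch base returning `None` is the protocol here: `jax_funcify_ScalarOp` only takes the Python-operator fast path when a registration exists, as the `Sub`, `IntDiv`, and `Mod` registrations above show. Extending the fast path is one more registration inside this module; a hypothetical sketch (this commit adds no `Pow` fast path, and the `Pow` import path is an assumption):

```python
from pytensor.scalar.basic import Pow  # assumed to live alongside Add/Mul/Sub

# Hypothetical registration, not part of this commit: route scalar `Pow`
# through the native Python `**` operator as well.
@jax_funcify_scalar_op_via_py_operators.register(Pow)
def jax_funcify_scalar_Pow(op):
    def elemwise(x, y):
        return x**y

    return elemwise
```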

tests/link/jax/test_scalar.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -23,6 +23,7 @@
     psi,
     sigmoid,
     softplus,
+    tri_gamma,
 )
 from pytensor.tensor.type import matrix, scalar, vector
 from tests.link.jax.test_basic import compare_jax_and_py
```
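The lone test change imports `tri_gamma`, PyTensor's trigamma function (the derivative of the digamma `psi` imported just above). As a standalone sanity check of what a JAX-side trigamma must satisfy, separate from the repo's test helpers:

```python
import jax.numpy as jnp
from jax.scipy.special import polygamma

# Trigamma is polygamma of order 1 and satisfies the recurrence
# psi_1(x + 1) = psi_1(x) - 1/x**2, which makes a cheap pointwise check.
x = jnp.array([0.5, 1.0, 2.5])
assert jnp.allclose(polygamma(1, x + 1), polygamma(1, x) - 1 / x**2, rtol=1e-5)
```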
