Skip to content

Commit 2c03ecf

Browse files
authored
Defer the use of nfunc_spec in JAX scalar dispatch
1 parent cb8b8ac commit 2c03ecf

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

pytensor/link/jax/dispatch/scalar.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,14 @@ def check_if_inputs_scalars(node):
6262

6363
@jax_funcify.register(ScalarOp)
6464
def jax_funcify_ScalarOp(op, node, **kwargs):
65-
func_name = op.nfunc_spec[0]
66-
6765
# We dispatch some PyTensor operators to Python operators
6866
# whenever the inputs are all scalars.
6967
are_inputs_scalars = check_if_inputs_scalars(node)
7068
if are_inputs_scalars:
7169
elemwise = elemwise_scalar(op)
7270
if elemwise is not None:
7371
return elemwise
74-
72+
func_name = op.nfunc_spec[0]
7573
if "." in func_name:
7674
jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
7775
else:

0 commit comments

Comments
 (0)