We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cb8b8ac commit 2c03ecfCopy full SHA for 2c03ecf
pytensor/link/jax/dispatch/scalar.py
@@ -62,16 +62,14 @@ def check_if_inputs_scalars(node):
62
63
@jax_funcify.register(ScalarOp)
64
def jax_funcify_ScalarOp(op, node, **kwargs):
65
- func_name = op.nfunc_spec[0]
66
-
67
# We dispatch some PyTensor operators to Python operators
68
# whenever the inputs are all scalars.
69
are_inputs_scalars = check_if_inputs_scalars(node)
70
if are_inputs_scalars:
71
elemwise = elemwise_scalar(op)
72
if elemwise is not None:
73
return elemwise
74
+ func_name = op.nfunc_spec[0]
75
if "." in func_name:
76
jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
77
else:
0 commit comments