@@ -37,7 +37,7 @@ def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Ca
     return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name))


-def check_if_inputs_scalars(node):
+def all_inputs_are_scalar(node):
     """Check whether all the inputs of an `Elemwise` are scalar values.

     `jax.lax` or `jax.numpy` functions systematically return `TracedArrays`,
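The renamed helper's docstring points at the core trade-off: Python operators applied to concrete scalars return concrete native values, while `jax.numpy` functions always return array (or, under `jax.jit`, traced) values. A minimal standalone check of that behavior, independent of PyTensor:

import jax
import jax.numpy as jnp

print(type(2 + 3))          # <class 'int'>: Python operators keep concrete values
print(type(jnp.add(2, 3)))  # a jax Array: jnp always wraps results in an array type

def f(x):
    # During `jax.jit` tracing, jnp results on traced inputs are abstract tracers,
    # so any concreteness of the inputs is lost.
    print(type(jnp.add(x, 1)))
    return x

jax.jit(f)(1.0)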
@@ -62,54 +62,68 @@ def check_if_inputs_scalars(node):

 @jax_funcify.register(ScalarOp)
 def jax_funcify_ScalarOp(op, node, **kwargs):
+    """Return a JAX function that implements the same computation as the Scalar Op.
+
+    This dispatch is expected to return a JAX function that works on Array inputs as Elemwise does,
+    even though it is dispatched on the Scalar Op.
+    """
+
     # We dispatch some PyTensor operators to Python operators
     # whenever the inputs are all scalars.
-    are_inputs_scalars = check_if_inputs_scalars(node)
-    if are_inputs_scalars:
-        elemwise = elemwise_scalar(op)
-        if elemwise is not None:
-            return elemwise
-    func_name = op.nfunc_spec[0]
+    if all_inputs_are_scalar(node):
+        jax_func = jax_funcify_scalar_op_via_py_operators(op)
+        if jax_func is not None:
+            return jax_func
+
+    nfunc_spec = getattr(op, "nfunc_spec", None)
+    if nfunc_spec is None:
+        raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
+
+    func_name = nfunc_spec[0]
     if "." in func_name:
-        jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
-    else:
-        jnp_func = getattr(jnp, func_name)
-
-    if hasattr(op, "nfunc_variadic"):
-        # These are special cases that handle invalid arities due to the broken
-        # PyTensor `Op` type contract (e.g. binary `Op`s that also function as
-        # their own variadic counterparts--even when those counterparts already
-        # exist as independent `Op`s).
-        jax_variadic_func = getattr(jnp, op.nfunc_variadic)
-
-        def elemwise(*args):
-            if len(args) > op.nfunc_spec[1]:
-                return jax_variadic_func(
-                    jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
-                )
-            else:
-                return jnp_func(*args)
-
-        return elemwise
+        jax_func = functools.reduce(getattr, [jax] + func_name.split("."))
     else:
-        return jnp_func
+        jax_func = getattr(jnp, func_name)
+
+    if len(node.inputs) > op.nfunc_spec[1]:
+        # Some Scalar Ops accept a variable number of inputs, behaving as variadic
+        # functions, even though the base Op from `func_name` is specified as binary.
+        # This happens with `Add`, which can work as a `Sum` over multiple scalars.
+        jax_variadic_func = getattr(jnp, op.nfunc_variadic, None)
+        if not jax_variadic_func:
+            raise NotImplementedError(
+                f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
+            )
+
+        def jax_func(*args):
+            return jax_variadic_func(
+                jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
+            )
+
+    return jax_func


 @functools.singledispatch
-def elemwise_scalar(op):
+def jax_funcify_scalar_op_via_py_operators(op):
+    """Specialized JAX dispatch for Elemwise operations where all inputs are scalar arrays.
+
+    Scalar (constant) arrays in the JAX backend get lowered to native Python types
+    (int, float), which can perform better with Python operators and, more importantly,
+    avoid upcasting to array types not supported by some JAX functions.
+    """
     return None


-@elemwise_scalar.register(Add)
-def elemwise_scalar_add(op):
+@jax_funcify_scalar_op_via_py_operators.register(Add)
+def jax_funcify_scalar_Add(op):
     def elemwise(*inputs):
         return sum(inputs)

     return elemwise


-@elemwise_scalar.register(Mul)
-def elemwise_scalar_mul(op):
+@jax_funcify_scalar_op_via_py_operators.register(Mul)
+def jax_funcify_scalar_Mul(op):
     import operator
     from functools import reduce
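To make the variadic branch concrete: `Add` is declared as a binary Op through its `nfunc_spec`, but a PyTensor node may apply it to more inputs, in which case the code above falls back to `op.nfunc_variadic` (`jnp.sum` in the case of `Add`). A standalone sketch of what the returned `jax_func` computes, with illustrative inputs:

import jax.numpy as jnp

# Three inputs to the binary `Add` take the variadic path:
x, y, z = jnp.ones(3), jnp.zeros(3), jnp.full(3, 2.0)
out = jnp.sum(jnp.stack(jnp.broadcast_arrays(x, y, z), axis=0), axis=0)
print(out)  # [3. 3. 3.], the same as x + y + z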
@@ -119,24 +133,24 @@ def elemwise(*inputs):
     return elemwise


-@elemwise_scalar.register(Sub)
-def elemwise_scalar_sub(op):
+@jax_funcify_scalar_op_via_py_operators.register(Sub)
+def jax_funcify_scalar_Sub(op):
     def elemwise(x, y):
         return x - y

     return elemwise


-@elemwise_scalar.register(IntDiv)
-def elemwise_scalar_intdiv(op):
+@jax_funcify_scalar_op_via_py_operators.register(IntDiv)
+def jax_funcify_scalar_IntDiv(op):
     def elemwise(x, y):
         return x // y

     return elemwise


-@elemwise_scalar.register(Mod)
-def elemwise_scalar_mod(op):
+@jax_funcify_scalar_op_via_py_operators.register(Mod)
+def jax_funcify_scalar_Mod(op):
     def elemwise(x, y):
         return x % y
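Because the Python-operator fallback is a `functools.singledispatch` registry whose default returns `None`, further scalar Ops can opt in by registering an implementation with the same shape as the ones above. A hypothetical registration, not part of this diff, assuming `Pow` is importable from `pytensor.scalar.basic`:

from pytensor.scalar.basic import Pow

@jax_funcify_scalar_op_via_py_operators.register(Pow)
def jax_funcify_scalar_Pow(op):
    # Hypothetical: the native ** operator keeps scalar constants as native
    # Python values instead of upcasting them to jnp arrays.
    def elemwise(x, y):
        return x**y

    return elemwise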