Commit b6fbacd

Numba dispatch of ScalarLoop
1 parent feafa7e commit b6fbacd
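
This change registers a Numba implementation for ScalarLoop, so graphs containing it compile through the Numba backend instead of falling back to object mode (the old test expected an "object mode" warning, which the updated test no longer does). Below is a minimal sketch of what this enables, adapted from the tests added in this commit; the mode string "NUMBA" is an assumption here, and any Numba-enabled mode should behave the same way:

import pytensor
import pytensor.scalar as ps
from pytensor.scalar import ScalarLoop

n_steps = ps.int64("n_steps")
x0 = ps.float64("x0")
const = ps.float64("const")

# One loop iteration: x <- x + const
op = ScalarLoop(init=[x0], constant=[const], update=[x0 + const])
x = op(n_steps, x0, const)

# mode="NUMBA" is assumed to select the Numba backend
fn = pytensor.function([n_steps, x0, const], x, mode="NUMBA")
print(fn(n_steps=5, x0=0.0, const=1.0))  # expected: 5.0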

3 files changed, +141 -11 lines

pytensor/link/numba/dispatch/scalar.py

Lines changed: 45 additions & 1 deletion

@@ -16,6 +16,7 @@
     get_name_for_object,
     unique_name_generator,
 )
+from pytensor.scalar import ScalarLoop
 from pytensor.scalar.basic import (
     Add,
     Cast,
@@ -69,7 +70,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
         scalar_func_numba = wrap_cython_function(
             cython_func, output_dtype, input_dtypes
         )
-        has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
+        has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch()
         input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
         output_inner_dtype = scalar_func_numba.numpy_output_dtype()

@@ -331,3 +332,46 @@ def softplus(x):
         return numba_basic.direct_cast(value, out_dtype)

     return softplus
+
+
+@numba_funcify.register(ScalarLoop)
+def numba_funcify_ScalarLoop(op, node, **kwargs):
+    inner_fn = numba_basic.numba_njit(numba_funcify(op.fgraph))
+
+    if op.is_while:
+        n_update = len(op.outputs) - 1
+
+        @numba_basic.numba_njit
+        def while_loop(n_steps, *inputs):
+            carry, constant = inputs[:n_update], inputs[n_update:]
+
+            until = False
+            for i in range(n_steps):
+                outputs = inner_fn(*carry, *constant)
+                carry, until = outputs[:-1], outputs[-1]
+                if until:
+                    break
+
+            return *carry, until
+
+        return while_loop
+
+    else:
+        n_update = len(op.outputs)
+
+        @numba_basic.numba_njit
+        def for_loop(n_steps, *inputs):
+            carry, constant = inputs[:n_update], inputs[n_update:]
+
+            if n_steps < 0:
+                raise ValueError("ScalarLoop does not have a termination condition.")
+
+            for i in range(n_steps):
+                carry = inner_fn(*carry, *constant)
+
+            if n_update == 1:
+                return carry[0]
+            else:
+                return carry
+
+        return for_loop
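
The while-loop branch above returns the carried state plus the final value of the until flag, so a compiled while ScalarLoop yields one extra boolean output indicating whether the termination condition was reached. A small sketch of that behavior, mirroring the test_scalar_while_loop case added below (again assuming a Numba-enabled mode such as mode="NUMBA"):

import pytensor
import pytensor.scalar as ps
from pytensor.scalar import ScalarLoop

n_steps = ps.int64("n_steps")
x0 = ps.float64("x0")
x = x0 + 1
until = x >= 10

op = ScalarLoop(init=[x0], update=[x], until=until)
fn = pytensor.function([n_steps, x0], op(n_steps, x0), mode="NUMBA")
print(fn(n_steps=20, x0=0))  # expected: [10.0, True], loop stops once x >= 10
print(fn(n_steps=5, x0=1))   # expected: [6.0, False], ran out of steps first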

tests/link/numba/test_elemwise.py

Lines changed: 35 additions & 10 deletions

@@ -586,18 +586,43 @@ def test_elemwise_multiple_inplace_outs():


 def test_scalar_loop():
-    a = float64("a")
-    scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
+    a_scalar = float64("a")
+    const_scalar = float64("const")
+    scalar_loop = pytensor.scalar.ScalarLoop(
+        init=[a_scalar],
+        update=[a_scalar + a_scalar + const_scalar],
+        constant=[const_scalar],
+    )

-    x = pt.tensor("x", shape=(3,))
-    elemwise_loop = Elemwise(scalar_loop)(3, x)
+    a = pt.tensor("a", shape=(3,))
+    const = pt.tensor("const", shape=(3,))
+    n_steps = 3
+    elemwise_loop = Elemwise(scalar_loop)(n_steps, a, const)

-    with pytest.warns(UserWarning, match="object mode"):
-        compare_numba_and_py(
-            [x],
-            [elemwise_loop],
-            (np.array([1, 2, 3], dtype="float64"),),
-        )
+    compare_numba_and_py(
+        [a, const],
+        [elemwise_loop],
+        [np.array([1, 2, 3], dtype="float64"), np.array([1, 1, 1], dtype="float64")],
+    )
+
+
+def test_gammainc_wrt_k_grad():
+    x = pt.vector("x", dtype="float64")
+    k = pt.vector("k", dtype="float64")
+
+    out = pt.gammainc(k, x)
+    grad_out = grad(out.sum(), k)
+
+    compare_numba_and_py(
+        [x, k],
+        [grad_out],
+        # These values of x and k trigger all the branches in the gradient of gammainc
+        [
+            np.array([0.0, 29.0, 31.0], dtype="float64"),
+            np.array([1.0, 13.0, 11.0], dtype="float64"),
+        ],
+        eval_obj_mode=False,
+    )


 class TestsBenchmark:

tests/link/numba/test_scalar.py

Lines changed: 61 additions & 0 deletions

@@ -6,6 +6,7 @@
 import pytensor.scalar.math as psm
 import pytensor.tensor as pt
 from pytensor import config, function
+from pytensor.scalar import ScalarLoop
 from pytensor.scalar.basic import Composite
 from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
@@ -184,3 +185,63 @@ def test_Softplus(dtype):
         strict=True,
         err_msg=f"Failed for value {value}",
     )
+
+
+class TestScalarLoop:
+    def test_scalar_for_loop_single_out(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        const = ps.float64("const")
+        x = x0 + const
+
+        op = ScalarLoop(init=[x0], constant=[const], update=[x])
+        x = op(n_steps, x0, const)
+
+        fn = function([n_steps, x0, const], [x], mode=numba_mode)
+
+        res_x = fn(n_steps=5, x0=0, const=1)
+        np.testing.assert_allclose(res_x, 5)
+
+        res_x = fn(n_steps=5, x0=0, const=2)
+        np.testing.assert_allclose(res_x, 10)
+
+        res_x = fn(n_steps=4, x0=3, const=-1)
+        np.testing.assert_allclose(res_x, -1)
+
+    def test_scalar_for_loop_multiple_outs(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        y0 = ps.int64("y0")
+        const = ps.float64("const")
+        x = x0 + const
+        y = y0 + 1
+
+        op = ScalarLoop(init=[x0, y0], constant=[const], update=[x, y])
+        x, y = op(n_steps, x0, y0, const)
+
+        fn = function([n_steps, x0, y0, const], [x, y], mode=numba_mode)
+
+        res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=1)
+        np.testing.assert_allclose(res_x, 5)
+        np.testing.assert_allclose(res_y, 5)
+
+        res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=2)
+        np.testing.assert_allclose(res_x, 10)
+        np.testing.assert_allclose(res_y, 5)
+
+        res_x, res_y = fn(n_steps=4, x0=3, y0=2, const=-1)
+        np.testing.assert_allclose(res_x, -1)
+        np.testing.assert_allclose(res_y, 6)
+
+    def test_scalar_while_loop(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        x = x0 + 1
+        until = x >= 10
+
+        op = ScalarLoop(init=[x0], update=[x], until=until)
+        fn = function([n_steps, x0], op(n_steps, x0), mode=numba_mode)
+        np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
+        np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
+        np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
+        np.testing.assert_allclose(fn(n_steps=0, x0=1), [1, False])
