Skip to content

Commit 5f374db

Browse files
committed
Fix numba AdvancedIncSubtensor1 with broadcasted values
1 parent 1e96b89 commit 5f374db

File tree

2 files changed

+90
-24
lines changed

2 files changed

+90
-24
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -604,36 +604,70 @@ def numba_funcify_IncSubtensor(op, node, **kwargs):
604604
return numba_njit(incsubtensor_fn, boundscheck=True)
605605

606606

607-
@numba_njit(boundscheck=True)
608-
def advancedincsubtensor1_inplace_set(x, vals, idxs):
609-
for idx, val in zip(idxs, vals):
610-
x[idx] = val
611-
return x
612-
613-
614-
@numba_njit(boundscheck=True)
615-
def advancedincsubtensor1_inplace_inc(x, vals, idxs):
616-
for idx, val in zip(idxs, vals):
617-
x[idx] += val
618-
return x
619-
620-
621607
@numba_funcify.register(AdvancedIncSubtensor1)
622608
def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
623609
inplace = op.inplace
624610
set_instead_of_inc = op.set_instead_of_inc
611+
x, vals, idxs = node.inputs
612+
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
613+
broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
625614

626615
if set_instead_of_inc:
627-
advancedincsubtensor1_inplace = global_numba_func(
628-
advancedincsubtensor1_inplace_set
629-
)
616+
if broadcast:
617+
618+
@numba_njit(boundscheck=True)
619+
def advancedincsubtensor1_inplace(x, val, idxs):
620+
if val.ndim == x.ndim:
621+
core_val = val[0]
622+
elif val.ndim == 0:
623+
# Workaround for https://github.com/numba/numba/issues/9573
624+
core_val = val.item()
625+
else:
626+
core_val = val
627+
628+
for idx in idxs:
629+
x[idx] = core_val
630+
return x
631+
632+
else:
633+
634+
@numba_njit(boundscheck=True)
635+
def advancedincsubtensor1_inplace(x, vals, idxs):
636+
if not len(idxs) == len(vals):
637+
raise ValueError("The number of indices and values must match.")
638+
for idx, val in zip(idxs, vals):
639+
x[idx] = val
640+
return x
630641
else:
631-
advancedincsubtensor1_inplace = global_numba_func(
632-
advancedincsubtensor1_inplace_inc
633-
)
642+
if broadcast:
643+
644+
@numba_njit(boundscheck=True)
645+
def advancedincsubtensor1_inplace(x, val, idxs):
646+
if val.ndim == x.ndim:
647+
core_val = val[0]
648+
elif val.ndim == 0:
649+
# Workaround for https://github.com/numba/numba/issues/9573
650+
core_val = val.item()
651+
else:
652+
core_val = val
653+
654+
for idx in idxs:
655+
x[idx] += core_val
656+
return x
657+
658+
else:
659+
660+
@numba_njit(boundscheck=True)
661+
def advancedincsubtensor1_inplace(x, vals, idxs):
662+
if not len(idxs) == len(vals):
663+
raise ValueError("The number of indices and values must match.")
664+
for idx, val in zip(idxs, vals):
665+
x[idx] += val
666+
return x
634667

635668
if inplace:
636-
return global_numba_func(advancedincsubtensor1_inplace)
669+
return advancedincsubtensor1_inplace
670+
637671
else:
638672

639673
@numba_njit

tests/link/numba/test_basic.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ def test_Subtensor(x, indices):
406406
"x, indices",
407407
[
408408
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)),
409+
(pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1],)),
409410
],
410411
)
411412
def test_AdvancedSubtensor1(x, indices):
@@ -498,6 +499,27 @@ def test_IncSubtensor(x, y, indices):
498499
pt.as_tensor(rng.poisson(size=(2, 4, 5))),
499500
([1, 1],),
500501
),
502+
# Broadcasting values
503+
(
504+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
505+
pt.as_tensor(rng.poisson(size=(1, 4, 5))),
506+
([0, 2, 0],),
507+
),
508+
(
509+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
510+
pt.as_tensor(rng.poisson(size=(5,))),
511+
([0, 2],),
512+
),
513+
(
514+
pt.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
515+
pt.as_tensor(rng.poisson(size=())),
516+
([2, 0],),
517+
),
518+
(
519+
pt.as_tensor(np.arange(5)),
520+
pt.as_tensor(rng.poisson(size=())),
521+
([2, 0],),
522+
),
501523
],
502524
)
503525
def test_AdvancedIncSubtensor1(x, y, indices):
@@ -511,11 +533,21 @@ def test_AdvancedIncSubtensor1(x, y, indices):
511533
out_fg = FunctionGraph([], [out_pt])
512534
compare_numba_and_py(out_fg, [])
513535

536+
# With symbolic inputs
514537
x_pt = x.type()
515-
out_pt = pt_subtensor.AdvancedIncSubtensor1(inplace=True)(x_pt, y, *indices)
538+
y_pt = y.type()
539+
540+
out_pt = pt_subtensor.AdvancedIncSubtensor1(inplace=True)(x_pt, y_pt, *indices)
516541
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1)
517-
out_fg = FunctionGraph([x_pt], [out_pt])
518-
compare_numba_and_py(out_fg, [x.data])
542+
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
543+
compare_numba_and_py(out_fg, [x.data, y.data])
544+
545+
out_pt = pt_subtensor.AdvancedIncSubtensor1(set_instead_of_inc=True, inplace=True)(
546+
x_pt, y_pt, *indices
547+
)
548+
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor1)
549+
out_fg = FunctionGraph([x_pt, y_pt], [out_pt])
550+
compare_numba_and_py(out_fg, [x.data, y.data])
519551

520552

521553
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)