Commit 63c513e

Simplify Repeat Op to only work with specific axis and vector repeats
1 parent 3dcf1fb commit 63c513e

File tree

5 files changed: +142 -187 lines


pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def numba_typify(data, dtype=None, **kwargs):
     return data
 
 
-def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
+def generate_fallback_impl(op, node, storage_map=None, **kwargs):
     """Create a Numba compatible function from a Pytensor `Op`."""
 
     warnings.warn(

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 13 additions & 35 deletions
@@ -6,7 +6,11 @@
 
 from pytensor.graph import Apply
 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
+from pytensor.link.numba.dispatch.basic import (
+    generate_fallback_impl,
+    get_numba_type,
+    numba_funcify,
+)
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.extra_ops import (
@@ -200,45 +204,19 @@ def ravelmultiindex(*inp):
 @numba_funcify.register(Repeat)
 def numba_funcify_Repeat(op, node, **kwargs):
     axis = op.axis
+    a, _ = node.inputs
 
-    use_python = False
-
-    if axis is not None:
-        use_python = True
-
-    if use_python:
-        warnings.warn(
-            (
-                "Numba will use object mode to allow the "
-                "`axis` argument to `numpy.repeat`."
-            ),
-            UserWarning,
-        )
-
-        ret_sig = get_numba_type(node.outputs[0].type)
+    # Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
+    if axis == 0 and a.type.ndim == 1:
 
-        @numba_basic.numba_njit
+        @numba_basic.numba_njit(inline="always")
         def repeatop(x, repeats):
-            with numba.objmode(ret=ret_sig):
-                ret = np.repeat(x, repeats, axis)
-            return ret
+            return np.repeat(x, repeats)
 
-    else:
-        repeats_ndim = node.inputs[1].ndim
+        return repeatop
 
-        if repeats_ndim == 0:
-
-            @numba_basic.numba_njit(inline="always")
-            def repeatop(x, repeats):
-                return np.repeat(x, repeats.item())
-
-        else:
-
-            @numba_basic.numba_njit(inline="always")
-            def repeatop(x, repeats):
-                return np.repeat(x, repeats)
-
-    return repeatop
+    else:
+        return generate_fallback_impl(op, node)
 
 
 @numba_funcify.register(Unique)
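Note: the fast path above leans on Numba's nopython support for `np.repeat` without an `axis` argument, which is exactly the axis=0, 1-D-input case; everything else now goes through `generate_fallback_impl` (the object-mode fallback). A minimal standalone sketch of the same pattern, illustrative only and assuming a Numba version whose `np.repeat` accepts array-valued repeats, as the dispatch above does:

    import numba
    import numpy as np

    @numba.njit
    def repeat_vector(x, repeats):
        # np.repeat without `axis` repeats each element of a 1-D array in place
        return np.repeat(x, repeats)

    x = np.array([1.0, 2.0, 3.0])
    reps = np.array([1, 3, 2], dtype=np.int64)
    print(repeat_vector(x, reps))  # [1. 2. 2. 2. 3. 3.]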

pytensor/tensor/extra_ops.py

Lines changed: 79 additions & 97 deletions
@@ -660,53 +660,72 @@ class Repeat(Op):
 
     __props__ = ("axis",)
 
-    def __init__(self, axis: int | None = None):
-        if axis is not None:
-            if not isinstance(axis, int) or axis < 0:
+    def __init__(self, axis: int):
+        if isinstance(axis, int):
+            if axis < 0:
                 raise ValueError(
-                    f"Repeat only accepts positive integer axis or None, got {axis}"
+                    f"Repeat Op only accepts positive integer axis, got {axis}. "
+                    "Use the helper `pt.repeat` to handle negative axis."
                 )
+        elif axis is None:
+            raise ValueError(
+                "Repeat Op only accepts positive integer axis. "
+                "Use the helper `pt.repeat` to handle axis=None."
+            )
+        else:
+            raise TypeError(
+                f"Invalid type for axis {axis}, expected int got {type(axis)}"
+            )
+
         self.axis = axis
 
     def make_node(self, x, repeats):
         x = ptb.as_tensor_variable(x)
         repeats = ptb.as_tensor_variable(repeats, dtype="int64")
 
-        if repeats.dtype not in integer_dtypes:
-            raise TypeError("repeats.dtype must be an integer.")
+        if repeats.type.ndim != 1:
+            if repeats.type.ndim == 0:
+                raise ValueError(
+                    f"repeats {repeats} must have 1 dimension, got 0. Use the helper `pt.repeat` to handle scalar repeats."
+                )
+            else:
+                raise ValueError(
+                    f"repeats {repeats} must have 1 dimension, got {repeats.type.ndim}"
+                )
+
+        if repeats.type.dtype not in integer_dtypes:
+            raise TypeError(
+                f"repeats {repeats} dtype must be an integer, got {repeats.type.dtype}."
+            )
 
         # Some dtypes are not supported by numpy's implementation of repeat.
         # Until another one is available, we should fail at graph construction
         # time, not wait for execution.
-        ptr_bitwidth = LOCAL_BITWIDTH
-        if ptr_bitwidth == 64:
-            numpy_unsupported_dtypes = ("uint64",)
-        if ptr_bitwidth == 32:
-            numpy_unsupported_dtypes = ("uint32", "int64", "uint64")
-
-        if repeats.dtype in numpy_unsupported_dtypes:
+        numpy_unsupported_dtypes = (
+            ("uint64",) if LOCAL_BITWIDTH == 64 else ("uint64", "uint32", "int64")
+        )
+        if repeats.type.dtype in numpy_unsupported_dtypes:
            raise TypeError(
-                (
-                    f"dtypes {numpy_unsupported_dtypes!s} are not supported by numpy.repeat "
-                    "for the 'repeats' parameter, "
-                ),
-                repeats.dtype,
+                f"repeats {repeats} dtype {repeats.type.dtype} are not supported by numpy.repeat"
            )
 
-        if self.axis is None:
-            out_shape = [None]
-        else:
+        shape = list(x.type.shape)
+        axis_input_dim_length = shape[self.axis]
+        axis_output_dim_length = None
+
+        if axis_input_dim_length is not None:
+            # If we have a static dim and constant repeats we can infer the length of the output dim
+            # Right now we only support homogenous constant repeats
            try:
-                const_reps = ptb.get_scalar_constant_value(repeats)
+                const_reps = ptb.get_underlying_scalar_constant_value(repeats)
            except NotScalarConstantError:
-                const_reps = None
-            if const_reps == 1:
-                out_shape = x.type.shape
+                pass
            else:
-                out_shape = list(x.type.shape)
-                out_shape[self.axis] = None
+                axis_output_dim_length = int(const_reps * axis_input_dim_length)
+
+        shape[self.axis] = axis_output_dim_length
 
-        out_type = TensorType(x.dtype, shape=out_shape)
+        out_type = TensorType(x.dtype, shape=shape)
         return Apply(self, [x, repeats], [out_type()])
 
     def perform(self, node, inputs, output_storage):
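Note: the new `make_node` only resolves a static output length when the input dim is static and `repeats` reduces to a single underlying scalar constant (homogeneous repeats); otherwise the dim stays unknown (`None`). A plain NumPy sketch of the arithmetic, with illustrative values only:

    import numpy as np

    x = np.zeros((4, 3))                  # static input shape (4, 3), axis = 0
    repeats = np.array([2, 2, 2, 2])      # homogeneous constant repeats

    const_reps = 2                        # the single underlying scalar constant
    inferred = const_reps * x.shape[0]    # axis_output_dim_length = 8

    out = np.repeat(x, repeats, axis=0)
    assert out.shape == (inferred, 3)     # matches the inferred static shape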
@@ -720,36 +739,19 @@ def grad(self, inputs, gout):
         (x, repeats) = inputs
         (gz,) = gout
         axis = self.axis
-        if repeats.ndim == 0:
-            # When axis is a scalar (same number of reps for all elements),
-            # We can split the repetitions into their own axis with reshape and sum them back
-            # to the original element location
-            sum_axis = x.ndim if axis is None else axis + 1
-            shape = list(x.shape)
-            shape.insert(sum_axis, repeats)
-            gx = gz.reshape(shape).sum(axis=sum_axis)
-
-        elif repeats.ndim == 1:
-            # To sum the gradients that belong to the same repeated x,
-            # We create a repeated eye and dot product it with the gradient.
-            axis_size = x.size if axis is None else x.shape[axis]
-            repeated_eye = repeat(
-                ptb.eye(axis_size), repeats, axis=0
-            )  # A sparse repeat would be neat
-
-            if axis is None:
-                gx = gz @ repeated_eye
-                # Undo the ravelling when axis=None
-                gx = gx.reshape(x.shape)
-            else:
-                # Place gradient axis at end for dot product
-                gx = ptb.moveaxis(gz, axis, -1)
-                gx = gx @ repeated_eye
-                # Place gradient back into the correct axis
-                gx = ptb.moveaxis(gx, -1, axis)
 
-        else:
-            raise ValueError()
+        # To sum the gradients that belong to the same repeated x,
+        # We create a repeated eye and dot product it with the gradient.
+        axis_size = x.shape[axis]
+        repeated_eye = repeat(
+            ptb.eye(axis_size), repeats, axis=0
+        )  # A sparse repeat would be neat
+
+        # Place gradient axis at end for dot product
+        gx = ptb.moveaxis(gz, axis, -1)
+        gx = gx @ repeated_eye
+        # Place gradient back into the correct axis
+        gx = ptb.moveaxis(gx, -1, axis)
 
         return [gx, DisconnectedType()()]
 
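Note: with scalar repeats gone, the gradient only needs the repeated-identity trick: repeating the rows of an identity matrix gives a (sum(repeats), axis_size) matrix that maps every output slice back to the input slice it was copied from, so a single matmul sums the gradients of all copies. A NumPy sketch of the idea, illustrative only, with axis=0:

    import numpy as np

    x = np.arange(6.0).reshape(3, 2)
    repeats = np.array([1, 3, 2])
    gz = np.ones((repeats.sum(), 2))   # upstream gradient, same shape as repeat(x, repeats, axis=0)

    # Each output row maps back to its source row; the matmul sums the copies.
    repeated_eye = np.repeat(np.eye(x.shape[0]), repeats, axis=0)   # (6, 3)

    gx = np.moveaxis(gz, 0, -1) @ repeated_eye   # (2, 6) @ (6, 3) -> (2, 3)
    gx = np.moveaxis(gx, -1, 0)                  # back to x's layout: (3, 2)
    assert np.array_equal(gx[:, 0], repeats)     # row i accumulated repeats[i] ones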
@@ -763,22 +765,8 @@ def infer_shape(self, fgraph, node, ins_shapes):
         dtype = None
         if repeats.dtype in ("uint8", "uint16", "uint32"):
             dtype = "int64"
-        if axis is None:
-            if repeats.ndim == 0:
-                if len(i0_shapes) == 0:
-                    out_shape = [repeats]
-                else:
-                    res = 1
-                    for d in i0_shapes:
-                        res = res * d
-                    out_shape = (res * repeats,)
-            else:
-                out_shape = [pt_sum(repeats, dtype=dtype)]
-        else:
-            if repeats.ndim == 0:
-                out_shape[axis] = out_shape[axis] * repeats
-            else:
-                out_shape[axis] = pt_sum(repeats, dtype=dtype)
+
+        out_shape[axis] = pt_sum(repeats, dtype=dtype)
         return [out_shape]
 
 
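Note: with only vector repeats left, the shape graph collapses to one line: the length of the repeated axis is just the sum of `repeats`. A quick NumPy check of that invariant, with illustrative values:

    import numpy as np

    x = np.zeros((3, 5))
    repeats = np.array([2, 0, 4])

    out = np.repeat(x, repeats, axis=0)
    assert out.shape == (repeats.sum(), 5)   # out_shape[axis] == sum(repeats) == 6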
@@ -843,48 +831,42 @@ def repeat(
     """
     a = ptb.as_tensor_variable(a)
 
-    if axis is not None:
+    if axis is None:
+        axis = 0
+        a = a.flatten()
+    else:
         axis = normalize_axis_index(axis, a.ndim)
 
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)
 
     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")
 
-    if repeats.ndim == 1 and not repeats.broadcastable[0]:
+    if repeats.type.broadcastable == (True,):
+        # This behaves the same as scalar repeat
+        repeats = repeats.squeeze()
+
+    if repeats.ndim == 1:
         # We only use the Repeat Op for vector repeats
         return Repeat(axis=axis)(a, repeats)
     else:
-        if repeats.ndim == 1:
-            repeats = repeats[0]
-
         if a.dtype == "uint64":
             # Multiplying int64 (shape) by uint64 (repeats) yields a float64
             # Which is not valid for the `reshape` operation at the end
             raise TypeError("repeat doesn't support dtype uint64")
 
-        if axis is None:
-            axis = 0
-            a = a.flatten()
-
-        repeat_shape = list(a.shape)
+        # Scalar repeat, we implement this with canonical Ops broadcast + reshape
+        a_shape = a.shape
 
-        # alloc_shape is the shape of the intermediate tensor which has
-        # an additional dimension comparing to x. We use alloc to
-        # allocate space for this intermediate tensor to replicate x
-        # along that additional dimension.
-        alloc_shape = repeat_shape[:]
-        alloc_shape.insert(axis + 1, repeats)
+        # Replicate a along a new axis (axis+1) repeats times
+        broadcast_shape = list(a_shape)
+        broadcast_shape.insert(axis + 1, repeats)
+        broadcast_a = broadcast_to(ptb.expand_dims(a, axis + 1), broadcast_shape)
 
-        # repeat_shape is now the shape of output, where shape[axis] becomes
-        # shape[axis]*repeats.
+        # Reshape broadcast_a to the final shape, merging axis and axis+1
+        repeat_shape = list(a_shape)
         repeat_shape[axis] = repeat_shape[axis] * repeats
-
-        # After the original tensor is duplicated along the additional
-        # dimension, we reshape it to the expected output shape
-        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
-            repeat_shape
-        )
+        return broadcast_a.reshape(repeat_shape)
 
 
 class Bartlett(Op):
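Note: the scalar-repeats path in the helper is now expressed with ordinary broadcasting: insert a new axis right after `axis`, broadcast `repeats` copies along it, then merge the two axes with a reshape. A NumPy sketch showing the equivalence to `np.repeat` with a scalar, illustrative values only:

    import numpy as np

    a = np.arange(6).reshape(2, 3)
    repeats, axis = 4, 1

    expanded = np.expand_dims(a, axis + 1)                   # (2, 3, 1)
    broadcast = np.broadcast_to(expanded, (2, 3, repeats))   # (2, 3, 4)
    merged = broadcast.reshape(2, 3 * repeats)               # (2, 12)

    assert np.array_equal(merged, np.repeat(a, repeats, axis=axis))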

tests/link/numba/test_extra_ops.py

Lines changed: 3 additions & 15 deletions
@@ -212,27 +212,15 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
 @pytest.mark.parametrize(
     "x, repeats, axis, exc",
     [
-        (
-            (pt.lscalar(), np.array(1, dtype="int64")),
-            (pt.lscalar(), np.array(0, dtype="int64")),
-            None,
-            None,
-        ),
-        (
-            (pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
-            (pt.lscalar(), np.array(1, dtype="int64")),
-            None,
-            None,
-        ),
         (
             (pt.lvector(), np.arange(2, dtype="int64")),
-            (pt.lvector(), np.array([1, 1], dtype="int64")),
-            None,
+            (pt.lvector(), np.array([1, 3], dtype="int64")),
+            0,
             None,
         ),
         (
             (pt.lmatrix(), np.zeros((2, 2), dtype="int64")),
-            (pt.lscalar(), np.array(1, dtype="int64")),
+            (pt.lvector(), np.array([1, 3], dtype="int64")),
             0,
             UserWarning,
         ),
