5 changes: 1 addition & 4 deletions pytensor/link/numba/dispatch/tensor_basic.py
@@ -136,10 +136,7 @@ def join(axis, *tensors):
 def numba_funcify_Split(op, **kwargs):
     @numba_basic.numba_njit
     def split(tensor, axis, indices):
-        # Work around for https://github.com/numba/numba/issues/8257
-        axis = axis % tensor.ndim
-        axis = numba_basic.to_scalar(axis)
-        return np.split(tensor, np.cumsum(indices)[:-1], axis=axis)
+        return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())

     return split

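For reference, a minimal plain-NumPy sketch (illustrative only, not part of the diff) of what the jitted body now computes, assuming `axis` arrives as a 0-d integer array; the sizes below are made up:

import numpy as np

x = np.arange(10)
axis = np.asarray(0)              # 0-d array, as seen inside the jitted function
indices = np.asarray([3, 3, 4])   # split sizes
# np.cumsum([3, 3, 4])[:-1] == [3, 6]; .item() turns the 0-d array into a Python int
parts = np.split(x, np.cumsum(indices)[:-1], axis=axis.item())
print([p.shape for p in parts])   # [(3,), (3,), (4,)]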
22 changes: 20 additions & 2 deletions pytensor/tensor/basic.py
@@ -2201,8 +2201,26 @@ def make_node(self, x, axis, splits):
             raise TypeError("`axis` parameter must be an integer scalar")

         inputs = [x, axis, splits]
-        out_type = TensorType(dtype=x.dtype, shape=(None,) * x.type.ndim)
-        outputs = [out_type() for i in range(self.len_splits)]
+
+        x_dtype = x.type.dtype
+        if isinstance(axis, Constant):
+            # In this case we can preserve more static shape info
+            static_axis = axis.data.item()
+            outputs = []
+            x_static_shape = list(x.type.shape)
+            for i in range(self.len_splits):
+                try:
+                    static_split_size = int(get_scalar_constant_value(splits[i]))
+                except NotScalarConstantError:
+                    static_split_size = None
+                static_out_shape = x_static_shape.copy()
+                static_out_shape[static_axis] = static_split_size
+                outputs.append(tensor(shape=tuple(static_out_shape), dtype=x_dtype))
+        else:
+            outputs = [
+                tensor(shape=(None,) * x.type.ndim, dtype=x_dtype)
+                for i in range(self.len_splits)
+            ]

         return Apply(self, inputs, outputs)

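A quick sketch of the extra static shape information this preserves, assuming the usual `pt.split` helper and a recent PyTensor API (the variable names and sizes are illustrative, not from the diff):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 6))  # second dim statically known
y1, y2 = pt.split(x, splits_size=[2, 4], n_splits=2, axis=1)

# With a constant axis and constant split sizes, each output now keeps a
# static size along the split axis instead of falling back to (None, None).
print(y1.type.shape)  # expected: (None, 2)
print(y2.type.shape)  # expected: (None, 4)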
34 changes: 34 additions & 0 deletions pytensor/tensor/rewriting/subtensor.py
@@ -21,6 +21,7 @@
     Join,
     MakeVector,
     ScalarFromTensor,
+    Split,
     TensorFromScalar,
     alloc,
     as_tensor,
@@ -616,6 +617,39 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
     return [node.inputs[0].dimshuffle(tuple(remain_dim))]


+@register_specialize("shape_unsafe")
+@node_rewriter(tracks=[Split])
+def split_to_subtensor(fgraph, node):
+    """Rewrite split(2)(x, 0) -> (x[:split_sizes[0]], x[split_sizes[0]:]).
+
+    This allows lifting the underlying split closer to the inputs and increases
+    fusion opportunities. It also handles unused split outputs automatically.
+
+    It only works for a constant axis.
+    """
+    x, axis, split_sizes = node.inputs
+
+    n_splits = node.op.len_splits
+    if n_splits <= 1:
+        return [x]
+
+    if not isinstance(axis, Constant):
+        return None
+
+    empty_slices = (slice(None),) * int(axis.data)
+    ys = []
+
+    end = split_sizes[0]
+    ys.append(x[(*empty_slices, slice(None, end))])
+    prev_start = end
+    for i in range(1, n_splits - 1):
+        end = prev_start + split_sizes[i]
+        ys.append(x[(*empty_slices, slice(prev_start, end))])
+        prev_start = end
+    ys.append(x[(*empty_slices, slice(prev_start, None))])
+    return ys
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
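To make the slicing pattern concrete, a plain-NumPy sketch (illustrative only, not part of the diff) of the equivalence the rewrite exploits for a constant axis=1 and three splits:

import numpy as np

x = np.arange(24).reshape(4, 6)
split_sizes = [1, 2, 3]
empty_slices = (slice(None),) * 1   # one full slice per axis before the split axis

y0 = x[(*empty_slices, slice(None, 1))]   # x[:, :1]
y1 = x[(*empty_slices, slice(1, 3))]      # x[:, 1:3]
y2 = x[(*empty_slices, slice(3, None))]   # x[:, 3:]

expected = np.split(x, np.cumsum(split_sizes)[:-1], axis=1)
assert all(np.array_equal(a, b) for a, b in zip((y0, y1, y2), expected))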