Skip to content

Commit 790b46f

Browse files
aseyboldt and ricardoV94
authored and committed
enh: Improve static shape of subtensor
1 parent 1d9fa84 commit 790b46f

File tree

2 files changed

+66
-22
lines changed

2 files changed

+66
-22
lines changed

pytensor/tensor/subtensor.py

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,16 @@ def analyze(x):
221221
step, is_step_constant = analyze(theslice.step)
222222
length, is_length_constant = analyze(length)
223223

224+
if (
225+
is_start_constant
226+
and is_stop_constant
227+
and is_step_constant
228+
and is_length_constant
229+
):
230+
_start, _stop, _step = slice(start, stop, step).indices(length)
231+
if _start <= _stop and _step >= 1:
232+
return slice(_start, _stop, _step), 1
233+
224234
if step is None:
225235
step = 1
226236
is_step_constant = True
@@ -722,32 +732,51 @@ def make_node(self, x, *inputs):
722732
f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}."
723733
)
724734

725-
# infer the broadcasting pattern
726-
padded = get_constant_idx(
727-
self.idx_list, (None,) + inputs, allow_partial=True
728-
) + [slice(None, None, None)] * (x.type.ndim - len(idx_list))
735+
padded = [
736+
*get_idx_list((None,) + inputs, self.idx_list),
737+
*[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
738+
]
729739

730740
out_shape = []
731-
for i, (p, s) in enumerate(zip(padded, x.type.shape)):
732-
if isinstance(p, slice):
733-
if s == 1:
734-
start = p.start
735-
try:
736-
start = get_underlying_scalar_constant_value(start)
737-
except NotScalarConstantError:
738-
pass
739-
if start is None or start == 0:
740-
start = p.start
741-
if start is None:
742-
start = 0
743-
if p.stop is None or (
744-
isinstance(p.stop, (int, np.integer, np.ndarray))
745-
and p.stop > start
746-
):
747-
out_shape.append(1)
748-
continue
749741

742+
def extract_const(value):
    """Try to resolve ``value`` to a compile-time constant.

    Returns a ``(value, is_constant)`` pair:

    * ``None`` is treated as a (trivially) constant slice component;
    * a symbolic value that wraps a scalar constant is replaced by that
      constant and flagged ``True``;
    * anything else is returned unchanged and flagged ``False``.
    """
    if value is None:
        return None, True
    try:
        # Resolves e.g. constant tensors / shared values to a Python scalar.
        return get_underlying_scalar_constant_value(value), True
    except NotScalarConstantError:
        # Not statically known; caller falls back to a dynamic (None) shape.
        return value, False
750+
751+
for the_slice, length in zip(padded, x.type.shape):
752+
if not isinstance(the_slice, slice):
753+
continue
754+
755+
if length is None:
750756
out_shape.append(None)
757+
continue
758+
759+
start = the_slice.start
760+
stop = the_slice.stop
761+
step = the_slice.step
762+
763+
is_slice_const = True
764+
765+
start, is_const = extract_const(start)
766+
is_slice_const = is_slice_const and is_const
767+
768+
stop, is_const = extract_const(stop)
769+
is_slice_const = is_slice_const and is_const
770+
771+
step, is_const = extract_const(step)
772+
is_slice_const = is_slice_const and is_const
773+
774+
if not is_slice_const:
775+
out_shape.append(None)
776+
continue
777+
778+
slice_length = len(range(*slice(start, stop, step).indices(length)))
779+
out_shape.append(slice_length)
751780

752781
return Apply(
753782
self,

tests/tensor/test_subtensor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2693,3 +2693,18 @@ def test_index_vars_to_types():
26932693
assert isinstance(x.type, scal.ScalarType)
26942694
res = index_vars_to_types(x)
26952695
assert res == x.type
2696+
2697+
2698+
@pytest.mark.parametrize(
    "x_shape, indices, expected",
    [
        [(None,), (slice(None, None),), (None,)],
        [(13,), (slice(None, 100),), (13,)],
        [(13,), (slice(-1, None),), (1,)],
        [(7, 13), (slice(None, None, 2), slice(-1, 1, -1)), (4, 11)],
    ],
)
def test_static_shapes(x_shape, indices, expected):
    """Constant slices of statically shaped dims yield static output shapes.

    An unknown (``None``) input dim stays unknown; known dims are folded
    with ``slice.indices`` semantics (clamping, negative indices, steps).
    """
    x = at.tensor(dtype="float64", shape=x_shape)
    y = x[indices]
    assert y.type.shape == expected

0 commit comments

Comments (0)