Skip to content

Commit efc9d69

Browse files
committed
local_subtensor_make_vector: don't return make_vector when slice keeps only one item
1 parent c2ede26 commit efc9d69

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -613,10 +613,6 @@ def local_subtensor_make_vector(fgraph, node):
613613
something more general for constant ``*Subtensor*`` graphs (or perhaps
614614
include this kind of work in the constant folding).
615615
"""
616-
617-
if not isinstance(node.op, Subtensor | AdvancedSubtensor1):
618-
return False
619-
620616
x = node.inputs[0]
621617

622618
if not (x.owner and isinstance(x.owner.op, MakeVector)):
@@ -666,7 +662,11 @@ def local_subtensor_make_vector(fgraph, node):
666662
const_slice = get_constant_idx(
667663
node.op.idx_list, node.inputs, allow_partial=False
668664
)[0]
669-
ret = make_vector_op(*x.owner.inputs[const_slice])
665+
sliced_inputs = x.owner.inputs[const_slice]
666+
if len(sliced_inputs) == 1:
667+
ret = expand_dims(sliced_inputs[0], axis=0)
668+
else:
669+
ret = make_vector_op(*sliced_inputs)
670670
copy_stack_trace(node.outputs, ret)
671671
return [ret]
672672
except NotScalarConstantError:

0 commit comments

Comments
 (0)