File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -613,10 +613,6 @@ def local_subtensor_make_vector(fgraph, node):
613
613
something more general for constant ``*Subtensor*`` graphs (or perhaps
614
614
include this kind of work in the constant folding).
615
615
"""
616
-
617
- if not isinstance (node .op , Subtensor | AdvancedSubtensor1 ):
618
- return False
619
-
620
616
x = node .inputs [0 ]
621
617
622
618
if not (x .owner and isinstance (x .owner .op , MakeVector )):
@@ -666,7 +662,11 @@ def local_subtensor_make_vector(fgraph, node):
666
662
const_slice = get_constant_idx (
667
663
node .op .idx_list , node .inputs , allow_partial = False
668
664
)[0 ]
669
- ret = make_vector_op (* x .owner .inputs [const_slice ])
665
+ sliced_inputs = x .owner .inputs [const_slice ]
666
+ if len (sliced_inputs ) == 1 :
667
+ ret = expand_dims (sliced_inputs [0 ], axis = 0 )
668
+ else :
669
+ ret = make_vector_op (* sliced_inputs )
670
670
copy_stack_trace (node .outputs , ret )
671
671
return [ret ]
672
672
except NotScalarConstantError :
You can’t perform that action at this time.
0 commit comments