Skip to content

Commit d3e389f

Browse files
committed
Cleanup Scan symbolic buffer size graph
Graph was being broken by Scalar/Tensor conversions that prevented fusion
1 parent 405106f commit d3e389f

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

pytensor/scan/rewriting.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from pytensor.graph.type import HasShape
4242
from pytensor.graph.utils import InconsistencyError
4343
from pytensor.raise_op import Assert
44-
from pytensor.scalar import ScalarConstant
44+
from pytensor.scalar import ScalarConstant, ScalarVariable
4545
from pytensor.scan.op import Scan, ScanInfo
4646
from pytensor.scan.utils import (
4747
ScanArgs,
@@ -54,6 +54,7 @@
5454
from pytensor.tensor.basic import (
5555
Alloc,
5656
AllocEmpty,
57+
ScalarFromTensor,
5758
get_scalar_constant_value,
5859
)
5960
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@@ -1300,7 +1301,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
13001301
# or not
13011302
flag_store = False
13021303

1303-
# 2.2 Loop over the clients
1304+
# 2.2 Loop over the clients to figure out how many steps we actually need to do in the Scan
13041305
for i, out in enumerate(node.outputs[:c_outs]):
13051306
# look at all its clients
13061307
slices[i] = []
@@ -1343,6 +1344,14 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
13431344
except KeyError:
13441345
length = out.shape[0]
13451346
cf_slice = get_canonical_form_slice(this_slice[0], length)
1347+
1348+
if (
1349+
isinstance(cf_slice[0], ScalarVariable)
1350+
and cf_slice[0].owner is not None
1351+
and isinstance(cf_slice[0].owner.op, ScalarFromTensor)
1352+
):
1353+
cf_slice = (cf_slice[0].owner.inputs[0], cf_slice[1])
1354+
13461355
slices[i] += [(cf_slice, this_slice)]
13471356

13481357
if isinstance(this_slice[0], slice) and this_slice[0].stop is None:

pytensor/tensor/rewriting/subtensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
get_slice_elements,
8585
inc_subtensor,
8686
indices_from_subtensor,
87+
undo_scalarization,
8788
)
8889
from pytensor.tensor.type import TensorType, integer_dtypes
8990
from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType
@@ -1136,6 +1137,7 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
11361137
# We are in the more complex case when we do not actually know
11371138
# if the first slice was in reverse or not.
11381139
# in case it was not in reverse:
1140+
sl2 = undo_scalarization(sl2)
11391141
p_val = sl1.start + sl2 * sl1.step
11401142
# case it was in reverse we need to realize that we do not want
11411143
# the k-th element from sl.start but the k-th element from

pytensor/tensor/subtensor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
nonzero,
3636
scalar_from_tensor,
3737
)
38+
from pytensor.tensor.basic import (
39+
constant as tensor_constant,
40+
)
3841
from pytensor.tensor.blockwise import vectorize_node_fallback
3942
from pytensor.tensor.elemwise import DimShuffle
4043
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
@@ -266,6 +269,15 @@ def get_canonical_form_slice(
266269
) -> tuple[ScalarVariable, int]: ...
267270

268271

272+
def undo_scalarization(x):
273+
if isinstance(x, ScalarVariable):
274+
if isinstance(x, ScalarConstant):
275+
return tensor_constant(x.data, dtype=x.dtype)
276+
elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
277+
return x.owner.inputs[0]
278+
return x
279+
280+
269281
def get_canonical_form_slice(
270282
theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
271283
length: int | np.integer | ScalarVariable | TensorVariable,
@@ -301,6 +313,7 @@ def get_canonical_form_slice(
301313
if isinstance(theslice, int | np.integer | ScalarVariable) or (
302314
isinstance(theslice, TensorVariable) and theslice.ndim == 0
303315
):
316+
theslice = undo_scalarization(theslice)
304317
cano = switch(lt(theslice, 0), (theslice + length), theslice)
305318
return scalar_from_tensor(cano), 1
306319
raise ValueError(f"Slice {theslice} is not a supported slice type.")
@@ -381,6 +394,7 @@ def analyze(x):
381394
elif is_stop_length:
382395
# start:length:1
383396
if is_start_constant and start >= 0:
397+
length = undo_scalarization(length)
384398
return slice(switch(lt(start, length), start, length), length, 1), 1
385399
start_plus_len = start + length
386400
start = switch(

0 commit comments

Comments
 (0)