Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,7 +1620,7 @@ def local_reduce_join(fgraph, node):
if not inp.type.broadcastable[join_axis]:
return None
# Most times inputs to join have an expand_dims, we eagerly clean up those here
new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
new_input = apply_local_dimshuffle_lift(fgraph, inp.squeeze(join_axis))
new_inputs.append(new_input)

ret = Elemwise(node.op.scalar_op)(*new_inputs)
Expand Down
45 changes: 28 additions & 17 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,14 +342,18 @@ def local_subtensor_of_dot(fgraph, node):
@node_rewriter([Subtensor])
def local_useless_slice(fgraph, node):
"""
Remove Subtensor of the form:
Remove useless slice(None) of the form:
1. X[0, :] -> X[0]
2. X[:] -> X

Also, rewrite Subtensor of the form:
Also, canonicalize slices of the form:
X[0:7:1] -> X[None:None:None]
where X is a vector of length 7

And:
X[-1:-8:-1] -> X[::-1]
where x is a vector of length 7

"""
idxs = get_idx_list(node.inputs, node.op.idx_list)
x = node.inputs[0]
Expand All @@ -368,32 +372,40 @@ def local_useless_slice(fgraph, node):
if s == slice(None):
continue

step = s.step

if step is None:
positive_step = True
elif isinstance(step, Constant):
step_value = step.data
positive_step = step.data > 0
if step_value == 1:
change_flag = True
step = None
else:
# We can only canonicalize start and stop if we know the sign of step
last_useful_idx = dim
continue

start = s.start
stop = s.stop
step = s.step
if (
start is not None
and extract_constant(start, only_process_constants=True) == 0
):

if start is not None and extract_constant(
start, only_process_constants=True
) == (0 if positive_step else -1):
change_flag = True
start = None

if (
stop is not None
and x.type.shape[dim] is not None
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
and extract_constant(stop, only_process_constants=True)
== (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1)
):
change_flag = True
stop = None

if (
step is not None
and extract_constant(step, only_process_constants=True) == 1
):
change_flag = True
step = None

if not (start is None and stop is None and step is None):
if start is not None or stop is not None or step is not None:
last_useful_idx = dim

new_idxs[dim] = slice(start, stop, step)
Expand All @@ -402,7 +414,6 @@ def local_useless_slice(fgraph, node):
out = x[tuple(new_idxs[: last_useful_idx + 1])]
# Copy over previous output stacktrace
copy_stack_trace(node.outputs, out)

return [out]


Expand Down
19 changes: 19 additions & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
local_mul_canonizer,
local_mul_switch_sink,
local_reduce_chain,
local_reduce_join,
local_sum_prod_of_mul_or_div,
mul_canonizer,
parse_mul_tree,
Expand Down Expand Up @@ -3415,6 +3416,24 @@ def test_not_supported_unequal_shapes(self):
f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
)

def test_non_ds_inputs(self):
"""Make sure rewrite works when inputs to join are not the usual DimShuffle.

Sum{axis=1} [id A] <Vector(float64, shape=(3,))>
└─ Join [id B] <Matrix(float64, shape=(3, 3))>
├─ 1 [id C] <Scalar(int8, shape=())>
├─ ExpandDims{axis=1} [id D] <Matrix(float64, shape=(3, 1))>
├─ Sub [id E] <Matrix(float64, shape=(3, 1))>
└─ Sub [id F] <Matrix(float64, shape=(3, 1))>
"""
x = vector("x")
out = join(0, exp(x[None]), log(x[None])).sum(axis=0)

fg = FunctionGraph([x], [out], clone=False)
[rewritten_out] = local_reduce_join.transform(fg, out.owner)
expected_out = add(exp(x), log(x))
assert equal_computations([rewritten_out], [expected_out])


def test_local_useless_adds():
default_mode = get_default_mode()
Expand Down
100 changes: 66 additions & 34 deletions tests/tensor/rewriting/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,42 +2404,74 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
np.testing.assert_allclose(fn(test_x, test_y), expected_out)


def test_slice_canonicalize():
rng = np.random.default_rng(43)
x = tensor(shape=(3, 5, None, 9))
test_x = rng.normal(size=(3, 5, 8, 9))
# Test case 1
y = x[0:None, 0:5, 0:7, 0:9:1]
f = pytensor.function([x], y, allow_input_downcast=True)

# Get the DeepCopy input and assert that the Op is a DeepCopy
test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp)

expected_y = x[None:None:None, None:None:None, None:7:None]

assert equal_computations([test_y], [expected_y])

np.testing.assert_allclose(
f(test_x),
test_x[
0:None, 0:5, 0:7, 0:9:1
], # Use the unoptimized slice to make sure our rewrite logic is correct
)
class TestUselessSlice:
def test_positive_step(self):
# When steps are positive, default start and end are `0` and `len(dim)`
x = tensor(shape=(3, 5, None, 9), dtype="float64")
test_x = np.random.normal(size=(3, 5, 8, 9))

y = x[0:3:1, 1:5:2, 0:7:1, 0:9:1]
f = pytensor.function([x], y)

# Get the DeepCopy input and assert that the Op is a DeepCopy
deep_copy_node = f.maker.fgraph.outputs[0].owner
assert isinstance(deep_copy_node.op, DeepCopyOp)

rewritten_y = deep_copy_node.inputs[0]
expected_y = x[None:None:None, 1:None:2, None:7:None]
assert equal_computations([rewritten_y], [expected_y])

np.testing.assert_allclose(
f(test_x),
# Use the unoptimized slice to make sure our rewrite logic is correct
test_x[0:3:1, 1:5:2, 0:7:1, 0:9:1],
)

# Test case 2
y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
f1 = pytensor.function([x], y1, allow_input_downcast=True)
def test_negative_step(self):
# When steps are negative, default start and end are `-1` and `-len(dim) - 1`
x = tensor(shape=(3, 5, None, 9), dtype="float64")
test_x = np.random.normal(size=(3, 5, 8, 9))

# Get the DeepCopy input and assert that the Op is a DeepCopy
test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
assert isinstance(f1.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
y = x[-1:-4:-1, 0:5:-2, -1:-9:-1, 0:9:None]
f = pytensor.function([x], y)

expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
# Get the DeepCopy input and assert that the Op is a DeepCopy
deep_copy_node = f.maker.fgraph.outputs[0].owner
assert isinstance(deep_copy_node.op, DeepCopyOp)

assert equal_computations([test_y1], [expected_y1])
rewritten_y = deep_copy_node.inputs[0]
expected_y = x[None:None:-1, 0:5:-2, None:-9:-1]
assert equal_computations([rewritten_y], [expected_y])

np.testing.assert_allclose(
f1(test_x),
test_x[0:-1, 0:5, 0:7, 0:-1:-1],
)
np.testing.assert_allclose(
f(test_x),
test_x[-1:-4:-1, 0:5:-2, -1:-9:-1, 0:9:None],
)

def test_unknown_step(self):
# If step isn't known, we can't canonicalize start and stop points
step = pt.scalar("step", dtype=int)
x = tensor(shape=(3, 5, None), dtype="float64")
test_x = np.random.normal(size=(3, 5, 7))

y = x[0:3:step, -1:-6:-step, ::]
# Need this rewrite when `FAST_COMPILE` otherwise step = -1 * step instead of neg(step)
mode = get_default_mode().including("local_mul_specialize")
f = pytensor.function([x, step], y, mode=mode)

# Get the DeepCopy input and assert that the Op is a DeepCopy
deep_copy_node = f.maker.fgraph.outputs[0].owner
assert isinstance(deep_copy_node.op, DeepCopyOp)

rewritten_y = deep_copy_node.inputs[0]
expected_y = x[0:3:step, -1:-6:-step]
assert equal_computations([rewritten_y], [expected_y])

np.testing.assert_allclose(
f(test_x, 1),
test_x[0:3:1, -1:-6:-1, ::],
)
np.testing.assert_allclose(
f(test_x, -2),
test_x[0:3:-2, -1:-6:2, ::],
)
Loading