
Commit 0b731c2

Rewrite concatenate([x, x]) as tile (#1714)
1 parent ee56826 commit 0b731c2

File tree

pytensor/tensor/rewriting/basic.py
tests/tensor/rewriting/test_basic.py
tests/tensor/test_basic.py

3 files changed: +124 -30 lines changed


pytensor/tensor/rewriting/basic.py

Lines changed: 48 additions & 0 deletions
@@ -77,6 +77,7 @@
     register_infer_shape,
     switch,
     tensor_copy,
+    tile,
     zeros,
     zeros_like,
 )
@@ -910,6 +911,53 @@ def local_join_make_vector(fgraph, node):
     return [ret]


+@register_canonicalize
+@node_rewriter([Join])
+def local_join_to_repeat(fgraph, node):
+    """Join(axis, x, x, x, ...) -> tile(x, reps)
+
+    When the same tensor is concatenated multiple times along an axis,
+    replace with a single tile operation which is more efficient.
+
+    Examples
+    --------
+    join(0, x, x, x) -> tile(x, (3, 1, 1, ...))
+    join(1, x, x) -> tile(x, (1, 2, 1, ...))
+    """
+    # Extract axis and the tensors being joined
+    axis, *tensors = node.inputs
+
+    # Optimization only applies when axis is constant
+    if not isinstance(axis, Constant):
+        return None
+
+    # Extract the Python integer from the constant
+    axis_val = axis.data
+
+    # Need at least 2 tensors to consider optimization
+    if len(tensors) <= 1:
+        return
+
+    # Check if all tensors are identical
+    if not all(t == tensors[0] for t in tensors[1:]):
+        return
+
+    n_reps = len(tensors)
+    first_tensor = tensors[0]
+    ndim = first_tensor.ndim
+
+    # Build reps tuple to repeat only along the join axis
+    # For shape (a, b, c) joining at axis 1: reps = (1, n_reps, 1)
+    # This directly concatenates n_reps copies along axis_val
+    reps = tuple(n_reps if i == axis_val else 1 for i in range(ndim))
+
+    result = tile(first_tensor, reps)
+
+    # Preserve debugging information
+    copy_stack_trace(node.outputs[0], result)
+    return [result]
+
+
 @register_specialize
 @register_canonicalize
 @register_useless
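For reference, here is a minimal NumPy sketch (not part of the commit) of the identity the reps construction above relies on: concatenating k copies of the same array along an axis is equivalent to tiling it with a reps tuple that is 1 everywhere except k at that axis.

import numpy as np

x = np.arange(6.0).reshape(2, 3)
k, axis = 3, 1
# reps is 1 everywhere except k at the join axis, e.g. (1, 3) here
reps = tuple(k if i == axis else 1 for i in range(x.ndim))
assert np.array_equal(np.concatenate([x] * k, axis=axis), np.tile(x, reps))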

tests/tensor/rewriting/test_basic.py

Lines changed: 76 additions & 11 deletions
@@ -1237,33 +1237,98 @@ def test_local_join_1():
     assert len([n for n in e if isinstance(n.op, Join)]) == 0
     assert f.maker.fgraph.outputs[0].dtype == config.floatX

-    # test we don't apply when their is 2 inputs
-    s = join(1, a, a)
+    # Test that join with 2 different inputs remains (not optimized away)
+    s = join(1, a, a[:, ::-1])
     f = function([a], s, mode=rewrite_mode)
-    val = f([[1]])
-    assert np.all(val == [[1]])
+    val = f([[1, 2]])
+    assert np.all(val == [[1, 2, 2, 1]])  # joined along axis 1
     e = f.maker.fgraph.toposort()
-    assert len([n for n in e if isinstance(n.op, Join)]) == 1
+    assert len([n for n in e if isinstance(n.op, Join)]) == 1  # join remains
     assert f.maker.fgraph.outputs[0].dtype == config.floatX


+def test_local_join_to_tile():
+    """Join(axis, x, x, ...) is rewritten to tile(x, reps) with reps[axis] = k.
+
+    This optimization applies whenever we concatenate the *same* tensor multiple
+    times along a given axis. It replaces the Join/concatenate with a Tile op.
+    """
+
+    # ---- Case 1: joining same vector along axis 0 ----
+    x = vector("x")
+    s = join(0, x, x, x)  # (3n,)
+    f = function([x], s, mode=rewrite_mode)
+
+    test_val = np.array([1.0, 2.0], dtype=config.floatX)
+    result = f(test_val)
+    expected = np.array([1.0, 2.0, 1.0, 2.0, 1.0, 2.0], dtype=config.floatX)
+    assert np.allclose(result, expected)
+
+    # Join should be optimized away
+    ops = f.maker.fgraph.toposort()
+    assert not any(isinstance(n.op, Join) for n in ops)
+
+    # ---- Case 2: joining same matrix along axis 0 ----
+    a = matrix("a")
+    s = join(0, a, a)  # (2m, n)
+    f = function([a], s, mode=rewrite_mode)
+
+    test_mat = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
+    result = f(test_mat)
+    expected = np.vstack([test_mat, test_mat])
+    assert np.allclose(result, expected)
+
+    ops = f.maker.fgraph.toposort()
+    assert not any(isinstance(n.op, Join) for n in ops)
+
+    # ---- Case 3: joining same matrix along axis 1 ----
+    s = join(1, a, a, a)  # (m, 3n)
+    f = function([a], s, mode=rewrite_mode)
+
+    result = f(test_mat)
+    expected = np.hstack([test_mat, test_mat, test_mat])
+    assert np.allclose(result, expected)
+
+    ops = f.maker.fgraph.toposort()
+    assert not any(isinstance(n.op, Join) for n in ops)
+
+    # ---- Case 4: different tensors -> should NOT optimize ----
+    y = vector("y")
+    s = join(0, x, y)  # inputs differ
+    f = function([x, y], s, mode=rewrite_mode)
+
+    test_vec1 = np.array([1.0, 2.0], dtype=config.floatX)
+    test_vec2 = np.array([3.0, 4.0], dtype=config.floatX)
+    result = f(test_vec1, test_vec2)
+    expected = np.array([1.0, 2.0, 3.0, 4.0], dtype=config.floatX)
+    assert np.allclose(result, expected)
+
+    # Join should still be present since inputs aren't identical
+    ops = f.maker.fgraph.toposort()
+    assert any(isinstance(n.op, Join) for n in ops)
+
+
 def test_local_join_empty():
-    # Vector case
+    # Vector case - empty tensors should be removed
     empty_vec = np.asarray([], dtype=config.floatX)
     vec = vector("vec")
-    s = pt.join(0, vec, vec, empty_vec)
+    s = pt.join(0, vec, vec[::-1], empty_vec)
     new_s = rewrite_graph(s)
-    assert equal_computations([new_s], [join(0, vec, vec)])
     assert new_s.dtype == s.dtype
+    # Verify that empty tensors are removed from the join
+    expected = pt.join(0, vec, vec[::-1])
+    assert equal_computations([new_s], [expected])

-    # Matrix case
+    # Matrix case - empty tensors should be removed
     empty_mat = np.zeros((2, 0), dtype=config.floatX)
     empty_sym_mat = matrix("m", shape=(2, 0))
     mat = matrix("mat", shape=(2, 10))
-    s = join(1, empty_mat, mat, empty_sym_mat, mat, mat)
+    s = join(1, empty_mat, mat, empty_sym_mat, mat[:, ::-1])
     new_s = rewrite_graph(s)
-    assert equal_computations([new_s], [join(1, mat, mat, mat)])
     assert new_s.dtype == s.dtype
+    # Verify that empty tensors are removed from the join
+    expected = join(1, mat, mat[:, ::-1])
+    assert equal_computations([new_s], [expected])

     # Join can be completely removed, but casting and specify_shape are propagated
     int_mat = matrix("int_mat", dtype=int)
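As a standalone complement to Case 1 above, here is a small hedged sketch (not part of the diff) of how one might observe the rewrite outside the test suite, assuming the canonicalization is active in PyTensor's default compilation mode:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Join

x = pt.vector("x")
f = pytensor.function([x], pt.join(0, x, x, x))
# With the rewrite applied, no Join node should remain in the compiled graph
print(any(isinstance(n.op, Join) for n in f.maker.fgraph.toposort()))  # expect False
print(f(np.arange(2.0).astype(x.dtype)))  # [0. 1. 0. 1. 0. 1.]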

tests/tensor/test_basic.py

Lines changed: 0 additions & 19 deletions
@@ -2020,25 +2020,6 @@ def test_rebroadcast(self):
         # This line used to crash.
         ptb.concatenate([x, -u], axis=2)

-    def test_concatenate_same(self):
-        # Test that we can concatenate the same tensor multiple time.
-
-        # In the past it was broken on the GPU.
-        rng = np.random.default_rng(seed=utt.fetch_seed())
-        T_shared = self.shared(rng.random((3, 4)).astype(self.floatX))
-        Tout = ptb.concatenate([T_shared, T_shared])
-        f = function([], Tout, mode=self.mode)
-        out = f()
-        if config.mode != "FAST_COMPILE":
-            assert [
-                True
-                for node in f.maker.fgraph.toposort()
-                if isinstance(node.op, type(self.join_op))
-            ]
-        assert np.allclose(
-            out, np.concatenate([T_shared.get_value(), T_shared.get_value()])
-        )
-
     def test_mixed_ndim_error(self):
         rng = np.random.default_rng(seed=utt.fetch_seed())
         v = self.shared(rng.random(4).astype(self.floatX))
