@@ -1237,33 +1237,98 @@ def test_local_join_1():
12371237 assert len ([n for n in e if isinstance (n .op , Join )]) == 0
12381238 assert f .maker .fgraph .outputs [0 ].dtype == config .floatX
12391239
1240- # test we don't apply when their is 2 inputs
1241- s = join (1 , a , a )
1240+ # Test that join with 2 different inputs remains (not optimized away)
1241+ s = join (1 , a , a [:, :: - 1 ] )
12421242 f = function ([a ], s , mode = rewrite_mode )
1243- val = f ([[1 ]])
1244- assert np .all (val == [[1 ]])
1243+ val = f ([[1 , 2 ]])
1244+ assert np .all (val == [[1 , 2 , 2 , 1 ]]) # joined along axis 1
12451245 e = f .maker .fgraph .toposort ()
1246- assert len ([n for n in e if isinstance (n .op , Join )]) == 1
1246+ assert len ([n for n in e if isinstance (n .op , Join )]) == 1 # join remains
12471247 assert f .maker .fgraph .outputs [0 ].dtype == config .floatX
12481248
12491249
1250+ def test_local_join_to_tile ():
1251+ """Join(axis, x, x, ...) is rewritten to tile(x, reps) with reps[axis] = k.
1252+
1253+ This optimization applies whenever we concatenate the *same* tensor multiple
1254+ times along a given axis. It replaces the Join/concatenate with a Tile op.
1255+ """
1256+
1257+ # ---- Case 1: joining same vector along axis 0 ----
1258+ x = vector ("x" )
1259+ s = join (0 , x , x , x ) # (3n,)
1260+ f = function ([x ], s , mode = rewrite_mode )
1261+
1262+ test_val = np .array ([1.0 , 2.0 ], dtype = config .floatX )
1263+ result = f (test_val )
1264+ expected = np .array ([1.0 , 2.0 , 1.0 , 2.0 , 1.0 , 2.0 ], dtype = config .floatX )
1265+ assert np .allclose (result , expected )
1266+
1267+ # Join should be optimized away
1268+ ops = f .maker .fgraph .toposort ()
1269+ assert not any (isinstance (n .op , Join ) for n in ops )
1270+
1271+ # ---- Case 2: joining same matrix along axis 0 ----
1272+ a = matrix ("a" )
1273+ s = join (0 , a , a ) # (2m, n)
1274+ f = function ([a ], s , mode = rewrite_mode )
1275+
1276+ test_mat = np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]], dtype = config .floatX )
1277+ result = f (test_mat )
1278+ expected = np .vstack ([test_mat , test_mat ])
1279+ assert np .allclose (result , expected )
1280+
1281+ ops = f .maker .fgraph .toposort ()
1282+ assert not any (isinstance (n .op , Join ) for n in ops )
1283+
1284+ # ---- Case 3: joining same matrix along axis 1 ----
1285+ s = join (1 , a , a , a ) # (m, 3n)
1286+ f = function ([a ], s , mode = rewrite_mode )
1287+
1288+ result = f (test_mat )
1289+ expected = np .hstack ([test_mat , test_mat , test_mat ])
1290+ assert np .allclose (result , expected )
1291+
1292+ ops = f .maker .fgraph .toposort ()
1293+ assert not any (isinstance (n .op , Join ) for n in ops )
1294+
1295+ # ---- Case 4: different tensors -> should NOT optimize ----
1296+ y = vector ("y" )
1297+ s = join (0 , x , y ) # inputs differ
1298+ f = function ([x , y ], s , mode = rewrite_mode )
1299+
1300+ test_vec1 = np .array ([1.0 , 2.0 ], dtype = config .floatX )
1301+ test_vec2 = np .array ([3.0 , 4.0 ], dtype = config .floatX )
1302+ result = f (test_vec1 , test_vec2 )
1303+ expected = np .array ([1.0 , 2.0 , 3.0 , 4.0 ], dtype = config .floatX )
1304+ assert np .allclose (result , expected )
1305+
1306+ # Join should still be present since inputs aren't identical
1307+ ops = f .maker .fgraph .toposort ()
1308+ assert any (isinstance (n .op , Join ) for n in ops )
1309+
1310+
12501311def test_local_join_empty ():
1251- # Vector case
1312+ # Vector case - empty tensors should be removed
12521313 empty_vec = np .asarray ([], dtype = config .floatX )
12531314 vec = vector ("vec" )
1254- s = pt .join (0 , vec , vec , empty_vec )
1315+ s = pt .join (0 , vec , vec [:: - 1 ] , empty_vec )
12551316 new_s = rewrite_graph (s )
1256- assert equal_computations ([new_s ], [join (0 , vec , vec )])
12571317 assert new_s .dtype == s .dtype
1318+ # Verify that empty tensors are removed from the join
1319+ expected = pt .join (0 , vec , vec [::- 1 ])
1320+ assert equal_computations ([new_s ], [expected ])
12581321
1259- # Matrix case
1322+ # Matrix case - empty tensors should be removed
12601323 empty_mat = np .zeros ((2 , 0 ), dtype = config .floatX )
12611324 empty_sym_mat = matrix ("m" , shape = (2 , 0 ))
12621325 mat = matrix ("mat" , shape = (2 , 10 ))
1263- s = join (1 , empty_mat , mat , empty_sym_mat , mat , mat )
1326+ s = join (1 , empty_mat , mat , empty_sym_mat , mat [:, :: - 1 ] )
12641327 new_s = rewrite_graph (s )
1265- assert equal_computations ([new_s ], [join (1 , mat , mat , mat )])
12661328 assert new_s .dtype == s .dtype
1329+ # Verify that empty tensors are removed from the join
1330+ expected = join (1 , mat , mat [:, ::- 1 ])
1331+ assert equal_computations ([new_s ], [expected ])
12671332
12681333 # Join can be completely removed, but casting and specify_shape are propagated
12691334 int_mat = matrix ("int_mat" , dtype = int )
0 commit comments