
Commit 00fe425

sharadmv authored and Google-ML-Automation committed
[Pallas/Fuser] Add basic support for some simple output concatenation fusion
PiperOrigin-RevId: 833584905
1 parent 85c9c1d commit 00fe425

File tree

2 files changed: +92 -0 lines changed


jax/_src/pallas/fuser/block_spec.py

Lines changed: 52 additions & 0 deletions
@@ -2375,3 +2375,55 @@ def new_index_map(*args):
   )
 
   return pallas_core.BlockSpec(tuple(new_block_shape), new_index_map)
+
+
+@register_push_block_spec_rule(lax.concatenate_p)
+def _concatenate_push_rule(
+    ctx: PushRuleContext,
+    *block_specs: pallas_core.BlockSpec,
+    dimension: int,
+):
+  avals_in = ctx.avals_in
+  block_shapes = [
+      pallas_core._canonicalize_block_shape(block_spec.block_shape)
+      for block_spec in block_specs
+  ]
+  # We only support concatenation if the entirety of the concat dimension is blocked.
+  assert all(hasattr(aval_in, 'shape') for aval_in in avals_in)
+  if not all(
+      block_shape[dimension] == pallas_core.Blocked(avals_in.shape[dimension])  # pytype: disable=attribute-error
+      for block_shape, avals_in in zip(block_shapes, avals_in)
+  ):
+    raise NotImplementedError(
+        f'concatenate not supported yet: {block_shapes=}, {avals_in=}'
+    )
+  def _new_index_map(*args):
+    all_indices = [block_spec.index_map(*args) for block_spec in block_specs]
+    # This is a very important check. We cannot actually construct a single BlockSpec
+    # for the output of concatenate if the indices are not identical across all the
+    # inputs. This is not something we can always enforce statically, but to be conservative
+    # we apply a very aggressive check. We can consider relaxing this later.
+    if not all(
+        (all_indices[0][i] is all_indices[j][i])
+        for i in range(len(all_indices[0]))
+        for j in range(len(all_indices))
+    ):
+      raise ValueError(
+          'Cannot statically prove that all input blocks to concatenate are the'
+          ' same.'
+      )
+    # If all block indices are the same, we are materializing the full concatenation along
+    # the concat dimension, so we use index 0.
+    base_indices = list(all_indices[0])
+    base_indices[dimension] = 0
+    return tuple(base_indices)
+
+  new_block_shape = list(block_specs[0].block_shape)
+  # Since the entirety of the concat dimension is materialized in the blocks,
+  # the new block size is the sum of the block sizes of the inputs along that
+  # dimension.
+  new_block_shape[dimension] = sum(
+      pallas_core.get_block_size(block_shape[dimension])
+      for block_shape in block_shapes
+  )
+  return pallas_core.BlockSpec(tuple(new_block_shape), _new_index_map)
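
The output-shape computation in this rule is simple arithmetic: because each input block must span its entire concat dimension, the pushed output block matches the input blocks everywhere except along `dimension`, where its size is the sum of the input block sizes. A minimal stand-alone sketch of that computation, using plain tuples instead of the real pallas_core block-shape types (the helper name below is hypothetical, not part of the commit):

# Hypothetical stand-in for the shape arithmetic in _concatenate_push_rule;
# plain int tuples replace pallas_core's Blocked dimension types.
def concat_block_shape(block_shapes, dimension):
  # Every input block fully covers `dimension`, so the output block spans the
  # whole concatenation there: sum the input block sizes along that axis.
  out = list(block_shapes[0])
  out[dimension] = sum(shape[dimension] for shape in block_shapes)
  return tuple(out)

# Two (512,)-blocked inputs concatenated on axis 0, as in the test below:
assert concat_block_shape([(512,), (512,)], dimension=0) == (1024,)
# Two (1, 128) blocks stacked along a new leading axis:
assert concat_block_shape([(1, 128), (1, 128)], dimension=0) == (2, 128)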

tests/pallas/fuser_block_spec_test.py

Lines changed: 40 additions & 0 deletions
@@ -1439,6 +1439,46 @@ def f(x):
     self.assertEqual(out_block_spec.block_shape, (128, 256))
     self.assertEqual(out_block_spec.index_map(1), (0, 1))
 
+  def test_concatenate_push(self):
+    def f(x1, x2):
+      return jnp.concatenate((x1, x2), axis=0)
+
+    x_type = jax.ShapeDtypeStruct((512,), jnp.float32)
+    block_spec = pl.BlockSpec((128,), lambda i: (i,))
+    with self.assertRaisesRegex(
+        NotImplementedError, 'concatenate not supported yet'
+    ):
+      block_spec_lib.push_block_spec(f, block_spec, block_spec)(x_type, x_type)
+    x_type = jax.ShapeDtypeStruct((512,), jnp.float32)
+    block_spec = pl.BlockSpec((512,), lambda i: (i,))
+    out_block_spec = block_spec_lib.push_block_spec(f, block_spec, block_spec)(
+        x_type, x_type
+    )
+    self.assertEqual(out_block_spec.block_shape, (1024,))
+    self.assertEqual(out_block_spec.index_map(0), (0,))
+
+    def f(x1, x2):
+      return jnp.stack([x1, x2], axis=0)
+
+    x_type = jax.ShapeDtypeStruct((512,), jnp.float32)
+    block_spec = pl.BlockSpec((128,), lambda i: (i,))
+    out_block_spec = block_spec_lib.push_block_spec(f, block_spec, block_spec)(
+        x_type, x_type
+    )
+    self.assertEqual(out_block_spec.block_shape, (2, 128))
+    self.assertEqual(out_block_spec.index_map(3), (0, 3))
+
+    def f(x1, x2):
+      return jnp.stack([x1, x2], axis=1)
+
+    x_type = jax.ShapeDtypeStruct((512,), jnp.float32)
+    block_spec = pl.BlockSpec((128,), lambda i: (i,))
+    out_block_spec = block_spec_lib.push_block_spec(f, block_spec, block_spec)(
+        x_type, x_type
+    )
+    self.assertEqual(out_block_spec.block_shape, (128, 2))
+    self.assertEqual(out_block_spec.index_map(3), (3, 0))
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
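
It is worth noting why the jnp.stack cases pass with a partial (128,) block while the plain jnp.concatenate case raises NotImplementedError: stack first reshapes each input to add a new axis of extent 1 and then concatenates along that new axis, so every input block trivially covers the whole concat dimension and the rule applies. A minimal usage sketch outside the test class, assuming the public pl.BlockSpec API shown above; the import path for block_spec_lib is a guess from the test file's layout, not confirmed by the commit:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
# Assumed import path for the fuser block_spec module exercised above.
from jax._src.pallas.fuser import block_spec as block_spec_lib

def f(x1, x2):
  # stack = reshape to add a length-1 axis, then concatenate along it.
  return jnp.stack([x1, x2], axis=0)

x_type = jax.ShapeDtypeStruct((512,), jnp.float32)
block_spec = pl.BlockSpec((128,), lambda i: (i,))
out_spec = block_spec_lib.push_block_spec(f, block_spec, block_spec)(
    x_type, x_type
)
# Each (128,)-block becomes a (1, 128) block of the stacked output; the two
# inputs together cover the new axis, so the pushed block shape is (2, 128).
print(out_spec.block_shape)  # (2, 128)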
