Skip to content

Commit 00d9f45

Browse files
sharadmv authored and Google-ML-Automation committed
[Pallas/Fuser] Add support for custom_jvp_call/pjit to push_block_spec
PiperOrigin-RevId: 733122108
1 parent d32e282 commit 00d9f45

File tree

2 files changed

+80
-34
lines changed

2 files changed

+80
-34
lines changed

jax/_src/pallas/fuser/block_spec.py

Lines changed: 49 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -702,24 +702,6 @@ def _eltwise_usage_rule(
702702
return [used_out]
703703

704704

705-
def register_eltwise_rule(prim: core.Primitive):
706-
register_pull_block_spec_rule(prim)(
707-
functools.partial(_eltwise_pull_rule, prim)
708-
)
709-
register_usage_rule(prim)(functools.partial(_eltwise_usage_rule, prim))
710-
register_eval_rule(prim)(functools.partial(_eltwise_eval_rule, prim))
711-
712-
713-
register_eltwise_rule(lax.exp_p)
714-
register_eltwise_rule(lax.tanh_p)
715-
register_eltwise_rule(lax.sin_p)
716-
register_eltwise_rule(lax.cos_p)
717-
register_eltwise_rule(lax.sqrt_p)
718-
register_eltwise_rule(lax.rsqrt_p)
719-
register_eltwise_rule(lax.log_p)
720-
register_eltwise_rule(lax.integer_pow_p)
721-
722-
723705
def _bcast_block_spec(block_spec: pallas_core.BlockSpec, i: int) -> pallas_core.BlockSpec:
724706
def new_index_map(i, *args):
725707
idx = block_spec.index_map(*args)
@@ -1332,22 +1314,19 @@ def wrapper(*args, **kwargs):
13321314
flat_block_specs, in_tree_ = tree_util.tree_flatten(
13331315
(in_spec_args, in_spec_kwargs)
13341316
)
1335-
jaxpr, values, in_tree, out_tree = _make_jaxpr(f, *args, **kwargs)
1317+
jaxpr, _, in_tree, out_tree = _make_jaxpr(f, *args, **kwargs)
13361318
if in_tree != in_tree_:
13371319
raise ValueError(f'Expected {in_tree} PyTree, got {in_tree_}')
1338-
return _push_block_spec(jaxpr, values, in_tree, out_tree, flat_block_specs)
1320+
out_bs = _push_block_spec_jaxpr(jaxpr, *flat_block_specs)
1321+
return tree_util.tree_unflatten(out_tree, out_bs)
13391322

13401323
return wrapper
13411324

13421325

1343-
def _push_block_spec(
1326+
def _push_block_spec_jaxpr(
13441327
jaxpr: core.Jaxpr,
1345-
values: tuple[Any, ...],
1346-
in_tree: Any,
1347-
out_tree: Any,
1348-
flat_block_specs,
1328+
*flat_block_specs,
13491329
) -> tuple[pallas_core.BlockSpec, ...]:
1350-
del values, in_tree
13511330
num_inputs = len(jaxpr.invars)
13521331
if len(flat_block_specs) != num_inputs:
13531332
raise ValueError(
@@ -1402,7 +1381,7 @@ def _write_block_spec(
14021381
)
14031382
if any(bs is pallas_core.no_block_spec for bs in out_block_specs):
14041383
raise ValueError('No block spec found for output')
1405-
return tree_util.tree_unflatten(out_tree, out_block_specs)
1384+
return out_block_specs # pytype: disable=bad-return-type
14061385

14071386

14081387
push_block_spec_rules: dict[core.Primitive, PushBlockSpecRuleFn] = {}
@@ -1467,17 +1446,16 @@ def _binop_push_rule(
14671446
register_binop_push_rule(ad_util.add_any_p)
14681447

14691448

1470-
def _elementwise_op_push_rule(
1471-
ctx: PullRuleContext, block_spec: pallas_core.BlockSpec
1449+
def _eltwise_push_rule(
1450+
prim: core.Primitive,
1451+
ctx: PullRuleContext,
1452+
block_spec: pallas_core.BlockSpec,
1453+
**params,
14721454
) -> pallas_core.BlockSpec:
1473-
del ctx
1455+
del prim, ctx, params
14741456
return block_spec
14751457

14761458

1477-
register_push_block_spec_rule(lax.exp_p)(_elementwise_op_push_rule)
1478-
register_push_block_spec_rule(lax.tanh_p)(_elementwise_op_push_rule)
1479-
1480-
14811459
@register_push_block_spec_rule(lax.transpose_p)
14821460
def _transpose_push_rule(
14831461
ctx: PushRuleContext,
@@ -1511,3 +1489,40 @@ def _convert_element_type_push_rule(
15111489
):
15121490
del ctx, new_dtype, weak_type, sharding
15131491
return block_spec
1492+
1493+
1494+
@register_push_block_spec_rule(custom_derivatives.custom_jvp_call_p)
1495+
def _custom_jvp_call_push_rule(
1496+
ctx, *block_specs, call_jaxpr: core.ClosedJaxpr, **_
1497+
):
1498+
assert not call_jaxpr.consts
1499+
return _push_block_spec_jaxpr(call_jaxpr.jaxpr, *block_specs)
1500+
1501+
1502+
@register_push_block_spec_rule(pjit.pjit_p)
1503+
def _pjit_push_rule(
1504+
ctx, *block_specs, jaxpr: core.ClosedJaxpr, **_
1505+
):
1506+
assert not jaxpr.consts
1507+
return _push_block_spec_jaxpr(jaxpr.jaxpr, *block_specs)
1508+
1509+
1510+
def register_eltwise_rule(prim: core.Primitive):
1511+
register_pull_block_spec_rule(prim)(
1512+
functools.partial(_eltwise_pull_rule, prim)
1513+
)
1514+
register_usage_rule(prim)(functools.partial(_eltwise_usage_rule, prim))
1515+
register_eval_rule(prim)(functools.partial(_eltwise_eval_rule, prim))
1516+
register_push_block_spec_rule(prim)(
1517+
functools.partial(_eltwise_push_rule, prim)
1518+
)
1519+
1520+
1521+
register_eltwise_rule(lax.exp_p)
1522+
register_eltwise_rule(lax.tanh_p)
1523+
register_eltwise_rule(lax.sin_p)
1524+
register_eltwise_rule(lax.cos_p)
1525+
register_eltwise_rule(lax.sqrt_p)
1526+
register_eltwise_rule(lax.rsqrt_p)
1527+
register_eltwise_rule(lax.log_p)
1528+
register_eltwise_rule(lax.integer_pow_p)

tests/pallas/fuser_block_spec_test.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -772,5 +772,36 @@ def f(x):
772772
)
773773

774774

775+
class PushBlockSpecTest(parameterized.TestCase):
776+
777+
def setUp(self):
778+
super().setUp()
779+
if config.enable_x64.value:
780+
self.skipTest('x64 not supported')
781+
782+
def test_jit(self):
783+
784+
def f(x):
785+
return jax.jit(jnp.sin)(x)
786+
787+
block_spec = pl.BlockSpec(
788+
(None, 1, 128, 128), lambda i, j, k, l, _: (i, l, k, j)
789+
)
790+
x_type = jax.ShapeDtypeStruct((1, 1, 512, 512), jnp.float32)
791+
out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type)
792+
self.assertEqual(out_block_spec.block_shape, block_spec.block_shape)
793+
794+
def test_custom_jvp(self):
795+
def f(x):
796+
return jax.nn.relu(x)
797+
798+
x_type = jax.ShapeDtypeStruct((1, 1, 512, 512), jnp.float32)
799+
block_spec = pl.BlockSpec(
800+
(None, 1, 128, 128), lambda i, j, k, l, _: (i, l, k, j)
801+
)
802+
out_block_spec = block_spec_lib.push_block_spec(f, block_spec)(x_type)
803+
self.assertEqual(out_block_spec.block_shape, block_spec.block_shape)
804+
805+
775806
if __name__ == '__main__':
776807
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)