@@ -702,24 +702,6 @@ def _eltwise_usage_rule(
702702 return [used_out ]
703703
704704
705- def register_eltwise_rule (prim : core .Primitive ):
706- register_pull_block_spec_rule (prim )(
707- functools .partial (_eltwise_pull_rule , prim )
708- )
709- register_usage_rule (prim )(functools .partial (_eltwise_usage_rule , prim ))
710- register_eval_rule (prim )(functools .partial (_eltwise_eval_rule , prim ))
711-
712-
713- register_eltwise_rule (lax .exp_p )
714- register_eltwise_rule (lax .tanh_p )
715- register_eltwise_rule (lax .sin_p )
716- register_eltwise_rule (lax .cos_p )
717- register_eltwise_rule (lax .sqrt_p )
718- register_eltwise_rule (lax .rsqrt_p )
719- register_eltwise_rule (lax .log_p )
720- register_eltwise_rule (lax .integer_pow_p )
721-
722-
723705def _bcast_block_spec (block_spec : pallas_core .BlockSpec , i : int ) -> pallas_core .BlockSpec :
724706 def new_index_map (i , * args ):
725707 idx = block_spec .index_map (* args )
@@ -1332,22 +1314,19 @@ def wrapper(*args, **kwargs):
13321314 flat_block_specs , in_tree_ = tree_util .tree_flatten (
13331315 (in_spec_args , in_spec_kwargs )
13341316 )
1335- jaxpr , values , in_tree , out_tree = _make_jaxpr (f , * args , ** kwargs )
1317+ jaxpr , _ , in_tree , out_tree = _make_jaxpr (f , * args , ** kwargs )
13361318 if in_tree != in_tree_ :
13371319 raise ValueError (f'Expected { in_tree } PyTree, got { in_tree_ } ' )
1338- return _push_block_spec (jaxpr , values , in_tree , out_tree , flat_block_specs )
1320+ out_bs = _push_block_spec_jaxpr (jaxpr , * flat_block_specs )
1321+ return tree_util .tree_unflatten (out_tree , out_bs )
13391322
13401323 return wrapper
13411324
13421325
1343- def _push_block_spec (
1326+ def _push_block_spec_jaxpr (
13441327 jaxpr : core .Jaxpr ,
1345- values : tuple [Any , ...],
1346- in_tree : Any ,
1347- out_tree : Any ,
1348- flat_block_specs ,
1328+ * flat_block_specs ,
13491329) -> tuple [pallas_core .BlockSpec , ...]:
1350- del values , in_tree
13511330 num_inputs = len (jaxpr .invars )
13521331 if len (flat_block_specs ) != num_inputs :
13531332 raise ValueError (
@@ -1402,7 +1381,7 @@ def _write_block_spec(
14021381 )
14031382 if any (bs is pallas_core .no_block_spec for bs in out_block_specs ):
14041383 raise ValueError ('No block spec found for output' )
1405- return tree_util . tree_unflatten ( out_tree , out_block_specs )
1384+ return out_block_specs # pytype: disable=bad-return-type
14061385
14071386
14081387push_block_spec_rules : dict [core .Primitive , PushBlockSpecRuleFn ] = {}
@@ -1467,17 +1446,16 @@ def _binop_push_rule(
14671446register_binop_push_rule (ad_util .add_any_p )
14681447
14691448
1470- def _elementwise_op_push_rule (
1471- ctx : PullRuleContext , block_spec : pallas_core .BlockSpec
1449+ def _eltwise_push_rule (
1450+ prim : core .Primitive ,
1451+ ctx : PullRuleContext ,
1452+ block_spec : pallas_core .BlockSpec ,
1453+ ** params ,
14721454) -> pallas_core .BlockSpec :
1473- del ctx
1455+ del prim , ctx , params
14741456 return block_spec
14751457
14761458
1477- register_push_block_spec_rule (lax .exp_p )(_elementwise_op_push_rule )
1478- register_push_block_spec_rule (lax .tanh_p )(_elementwise_op_push_rule )
1479-
1480-
14811459@register_push_block_spec_rule (lax .transpose_p )
14821460def _transpose_push_rule (
14831461 ctx : PushRuleContext ,
@@ -1511,3 +1489,40 @@ def _convert_element_type_push_rule(
15111489):
15121490 del ctx , new_dtype , weak_type , sharding
15131491 return block_spec
1492+
1493+
1494+ @register_push_block_spec_rule (custom_derivatives .custom_jvp_call_p )
1495+ def _custom_jvp_call_push_rule (
1496+ ctx , * block_specs , call_jaxpr : core .ClosedJaxpr , ** _
1497+ ):
1498+ assert not call_jaxpr .consts
1499+ return _push_block_spec_jaxpr (call_jaxpr .jaxpr , * block_specs )
1500+
1501+
1502+ @register_push_block_spec_rule (pjit .pjit_p )
1503+ def _pjit_push_rule (
1504+ ctx , * block_specs , jaxpr : core .ClosedJaxpr , ** _
1505+ ):
1506+ assert not jaxpr .consts
1507+ return _push_block_spec_jaxpr (jaxpr .jaxpr , * block_specs )
1508+
1509+
1510+ def register_eltwise_rule (prim : core .Primitive ):
1511+ register_pull_block_spec_rule (prim )(
1512+ functools .partial (_eltwise_pull_rule , prim )
1513+ )
1514+ register_usage_rule (prim )(functools .partial (_eltwise_usage_rule , prim ))
1515+ register_eval_rule (prim )(functools .partial (_eltwise_eval_rule , prim ))
1516+ register_push_block_spec_rule (prim )(
1517+ functools .partial (_eltwise_push_rule , prim )
1518+ )
1519+
1520+
1521+ register_eltwise_rule (lax .exp_p )
1522+ register_eltwise_rule (lax .tanh_p )
1523+ register_eltwise_rule (lax .sin_p )
1524+ register_eltwise_rule (lax .cos_p )
1525+ register_eltwise_rule (lax .sqrt_p )
1526+ register_eltwise_rule (lax .rsqrt_p )
1527+ register_eltwise_rule (lax .log_p )
1528+ register_eltwise_rule (lax .integer_pow_p )
0 commit comments