
Commit 4f93563

cperivol authored and Google-ML-Automation committed
[pallas] Support for setting explicit backends to pallas_call.
PiperOrigin-RevId: 688511303
1 parent 2db03ba commit 4f93563
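For context, a minimal sketch of how the new keyword might be used. The kernel, shapes, and backend choice below are illustrative (not part of the commit), and the example assumes a CUDA GPU with the Triton backend available:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # Illustrative element-wise kernel: read the input block, write the output block.
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8, dtype=jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    backend="triton",  # new keyword from this commit; None (the default) lets Pallas decide
)(x)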

File tree: 3 files changed, 42 additions & 8 deletions

jax/_src/pallas/mosaic/core.py

Lines changed: 1 addition & 0 deletions
@@ -270,6 +270,7 @@ def body(*args):
         compiler_params=dict(
             mosaic=dict(dimension_semantics=("parallel",)),
         ),
+        backend="mosaic_tpu",
     )(*args)
     return out, ()

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 1 addition & 0 deletions
@@ -512,6 +512,7 @@ def body(*args):
         out_specs=[any_spec] * len(in_avals),
         input_output_aliases={i: i for i in range(len(in_avals))},
         grid=tuple(mesh.shape.items()),
+        backend="mosaic_gpu",
     )(*args)
     return out, ()

jax/_src/pallas/pallas_call.py

Lines changed: 40 additions & 8 deletions
@@ -19,7 +19,7 @@
 import dataclasses
 from functools import partial, reduce
 import itertools
-from typing import Any
+from typing import Any, Literal
 
 import jax
 from jax import lax
@@ -362,6 +362,7 @@ def _pallas_call_jvp_rule(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None,
 ):
   if grid_mapping.num_dynamic_grid_bounds:
     raise NotImplementedError("interpret with dynamic grid bounds unsupported")
@@ -425,7 +426,8 @@ def _pallas_call_jvp_rule(
       input_output_aliases=(),
       compiler_params=compiler_params,
       cost_estimate=jvp_cost_estimate,
-      out_avals=(*out_avals, *out_avals)
+      out_avals=(*out_avals, *out_avals),
+      backend=backend,
   )
   out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
   return out_primals, out_tangents
@@ -560,6 +562,7 @@ def _batch_with_explicit_loop(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None,
 ):
   """Batch the pallas_call by calling it in loop over the batch size.
@@ -627,6 +630,7 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
         out_avals=out_avals,
+        backend=backend,
     )
     for i, batch_out_array in enumerate(batch_out):
       state[i] = jax.lax.dynamic_update_index_in_dim(
@@ -656,6 +660,7 @@ def _pallas_call_batching_rule(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None,
 ):
   def _maybe_squeeze_out_bdim(
       x: jax.Array, bdim: int | batching.NotMapped
@@ -688,6 +693,7 @@ def get_size(i, x, d):
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
         out_avals=out_avals,
+        backend=backend,
     )
     return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)
 
@@ -721,6 +727,7 @@ def get_size(i, x, d):
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
         out_avals=out_avals,
+        backend=backend,
     )
   else:
     pass  # No dynamic grid dimensions
@@ -755,6 +762,7 @@ def get_size(i, x, d):
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
         out_avals=out_avals,
+        backend=backend,
     )
 
   if not dims:
@@ -1128,6 +1136,7 @@ def index_rewrite_kernel(*indexer_args):
       compiler_params=compiler_params,
       cost_estimate=batched_cost_estimate,
       out_avals=batched_out_avals,
+      backend=backend,
   )
   return out, (0,) * len(out)
@@ -1441,9 +1450,15 @@ def _unsupported_lowering_error(platform: str) -> Exception:
       " https://jax.readthedocs.io/en/latest/installation.html."
   )
 
+_Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"]
+
 
 def _pallas_call_lowering(
-    ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
+    ctx: mlir.LoweringRuleContext,
+    *in_nodes,
+    interpret: bool,
+    backend: _Backend | None,
+    **params,
 ):
   if params['jaxpr'].constvars:
     raise ValueError('Cannot lower a pallas_call with constants.')
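The new _Backend alias is a typing.Literal, so a static type checker can reject unsupported strings before anything is lowered. A small standalone illustration (not part of the commit; the names mirror the diff above):

from typing import Literal

_Backend = Literal["mosaic_tpu", "triton", "mosaic_gpu"]

def lower(backend: _Backend | None = None) -> None:
  ...  # placeholder body; only the annotation matters for this illustration

lower("triton")  # accepted
lower("cuda")    # flagged by a type checker such as mypy or pyright (Literal is not checked at runtime)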
@@ -1460,6 +1475,8 @@ def cpu_lowering(ctx: mlir.LoweringRuleContext,
   def tpu_lowering(ctx: mlir.LoweringRuleContext,
                    *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                    **params):
+    if backend and backend != "mosaic_tpu":
+      raise ValueError("Only mosaic backend supported for TPU")
     if mosaic_tpu_backend is None:
       raise _unsupported_lowering_error("tpu")
     return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
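Per the check added above, requesting a non-TPU backend on a TPU host is expected to fail at lowering time. A hedged sketch of that failure mode, reusing the illustrative kernel and input from the first example on this page:

# On a TPU host, this is expected to raise
# ValueError("Only mosaic backend supported for TPU") during lowering.
pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    backend="triton",
)(x)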
@@ -1470,12 +1487,21 @@ def gpu_lowering(ctx: mlir.LoweringRuleContext,
                    *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                    **params):
     try:
-      if _PALLAS_USE_MOSAIC_GPU.value:
-        from jax._src.pallas.mosaic_gpu import pallas_call_registration
-      else:
-        from jax._src.pallas.triton import pallas_call_registration  # type: ignore
-    except ImportError:
+      match backend:
+        case "mosaic_gpu":
+          from jax._src.pallas.mosaic_gpu import pallas_call_registration
+        case "triton":
+          from jax._src.pallas.triton import pallas_call_registration  # type: ignore
+        case None:
+          if _PALLAS_USE_MOSAIC_GPU.value:
+            from jax._src.pallas.mosaic_gpu import pallas_call_registration
+          else:
+            from jax._src.pallas.triton import pallas_call_registration  # type: ignore
+        case _:
+          raise ValueError(f"Unsupported backend: {backend}")
+    except ImportError as e:
       raise _unsupported_lowering_error("gpu")
+
     return pallas_call_registration.pallas_call_lowering(
         ctx, *in_nodes, **params
     )
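Restated outside the lowering rule, the GPU selection order introduced above is: an explicit backend wins, backend=None falls back to the existing _PALLAS_USE_MOSAIC_GPU flag, and any other value is rejected. A standalone sketch of that precedence (the function name is illustrative, and module paths are returned as strings instead of being imported):

def _select_gpu_registration(backend: str | None, use_mosaic_gpu_flag: bool) -> str:
  # Mirrors the match statement above without performing the imports.
  if backend == "mosaic_gpu":
    return "jax._src.pallas.mosaic_gpu"
  if backend == "triton":
    return "jax._src.pallas.triton"
  if backend is None:
    # Preserve the pre-existing, flag-driven default.
    return "jax._src.pallas.mosaic_gpu" if use_mosaic_gpu_flag else "jax._src.pallas.triton"
  raise ValueError(f"Unsupported backend: {backend}")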
@@ -1544,6 +1570,7 @@ def _pallas_call_state_discharge_rule(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None = None
 ):
   del avals_out
   assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
@@ -1645,6 +1672,7 @@ def _rewritten_body(*args):
       compiler_params=compiler_params,
       cost_estimate=cost_estimate,
       out_avals=new_out_avals,
+      backend=backend,
   )
   refs_out, rest = split_list(out_flat, [num_refs])
   updated_vals_in = refs_out + [None] * len(rest_in_avals)
@@ -1666,6 +1694,7 @@ def pallas_call(
     name: str | None = None,
     compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
     cost_estimate: CostEstimate | None = None,
+    backend: _Backend | None = None,
 ) -> Callable[..., Any]:
   """Invokes a Pallas kernel on some inputs.
@@ -1715,6 +1744,8 @@
       platform is either 'mosaic' or 'triton'. It is also possible
       to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and
       `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs.
+    backend: Optional string literal one of "mosaic_tpu", "triton" or "mosaic_gpu"
+      determining the backend to be used. None means let pallas decide.
 
 
   Returns:
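Tying the docstring to the lowering rules above: on GPU, backend=None keeps the flag-driven default (usually Triton), while an explicit value pins the lowering path. A hedged sketch, reusing the illustrative kernel and input from the first example and assuming a GPU with the Mosaic GPU backend available:

out_mosaic = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    backend="mosaic_gpu",  # force Mosaic GPU instead of the default Triton lowering
)(x)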
@@ -1857,6 +1888,7 @@ def wrapped(*args):
         input_output_aliases=tuple(input_output_aliases.items()),
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
+        backend=backend,
     )
     out = tree_util.tree_unflatten(out_tree, out_flat)
     return out
