 import dataclasses
 from functools import partial, reduce
 import itertools
-from typing import Any
+from typing import Any, Literal

 import jax
 from jax import lax
@@ -362,6 +362,7 @@ def _pallas_call_jvp_rule(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None,
 ):
   if grid_mapping.num_dynamic_grid_bounds:
     raise NotImplementedError("interpret with dynamic grid bounds unsupported")
@@ -425,7 +426,8 @@ def _pallas_call_jvp_rule(
       input_output_aliases=(),
       compiler_params=compiler_params,
       cost_estimate=jvp_cost_estimate,
-      out_avals=(*out_avals, *out_avals)
+      out_avals=(*out_avals, *out_avals),
+      backend=backend,
   )
   out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2])
   return out_primals, out_tangents
@@ -560,6 +562,7 @@ def _batch_with_explicit_loop(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None,
 ):
564567 """Batch the pallas_call by calling it in loop over the batch size.
565568
@@ -627,6 +630,7 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]:
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
         out_avals=out_avals,
+        backend=backend,
     )
     for i, batch_out_array in enumerate(batch_out):
       state[i] = jax.lax.dynamic_update_index_in_dim(
@@ -656,6 +660,7 @@ def _pallas_call_batching_rule(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None,
 ):
   def _maybe_squeeze_out_bdim(
       x: jax.Array, bdim: int | batching.NotMapped
@@ -688,6 +693,7 @@ def get_size(i, x, d):
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
         out_avals=out_avals,
+        backend=backend,
     )
     return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out)

@@ -721,6 +727,7 @@ def get_size(i, x, d):
           compiler_params=compiler_params,
           cost_estimate=cost_estimate,
           out_avals=out_avals,
+          backend=backend,
       )
     else:
       pass  # No dynamic grid dimensions
@@ -755,6 +762,7 @@ def get_size(i, x, d):
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
         out_avals=out_avals,
+        backend=backend,
     )

   if not dims:
@@ -1128,6 +1136,7 @@ def index_rewrite_kernel(*indexer_args):
       compiler_params=compiler_params,
       cost_estimate=batched_cost_estimate,
       out_avals=batched_out_avals,
+      backend=backend,
   )
   return out, (0,) * len(out)

@@ -1441,9 +1450,15 @@ def _unsupported_lowering_error(platform: str) -> Exception:
14411450 " https://jax.readthedocs.io/en/latest/installation.html."
14421451 )
14431452
1453+ _Backend = Literal ["mosaic_tpu" , "triton" , "mosaic_gpu" ]
1454+
14441455
 def _pallas_call_lowering(
-    ctx: mlir.LoweringRuleContext, *in_nodes, interpret: bool, **params
+    ctx: mlir.LoweringRuleContext,
+    *in_nodes,
+    interpret: bool,
+    backend: _Backend | None,
+    **params,
 ):
   if params['jaxpr'].constvars:
     raise ValueError('Cannot lower a pallas_call with constants.')
@@ -1460,6 +1475,8 @@ def cpu_lowering(ctx: mlir.LoweringRuleContext,
   def tpu_lowering(ctx: mlir.LoweringRuleContext,
                    *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                    **params):
+    if backend and backend != "mosaic_tpu":
+      raise ValueError("Only the mosaic_tpu backend is supported on TPU")
     if mosaic_tpu_backend is None:
       raise _unsupported_lowering_error("tpu")
     return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
@@ -1470,12 +1487,21 @@ def gpu_lowering(ctx: mlir.LoweringRuleContext,
                    *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
                    **params):
     try:
-      if _PALLAS_USE_MOSAIC_GPU.value:
-        from jax._src.pallas.mosaic_gpu import pallas_call_registration
-      else:
-        from jax._src.pallas.triton import pallas_call_registration  # type: ignore
-    except ImportError:
+      match backend:
+        case "mosaic_gpu":
+          from jax._src.pallas.mosaic_gpu import pallas_call_registration
+        case "triton":
+          from jax._src.pallas.triton import pallas_call_registration  # type: ignore
+        case None:
+          if _PALLAS_USE_MOSAIC_GPU.value:
+            from jax._src.pallas.mosaic_gpu import pallas_call_registration
+          else:
+            from jax._src.pallas.triton import pallas_call_registration  # type: ignore
+        case _:
+          raise ValueError(f"Unsupported backend: {backend}")
+    except ImportError as e:
       raise _unsupported_lowering_error("gpu")
+
     return pallas_call_registration.pallas_call_lowering(
         ctx, *in_nodes, **params
     )
@@ -1544,6 +1570,7 @@ def _pallas_call_state_discharge_rule(
     compiler_params: Any,
     cost_estimate: CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
+    backend: _Backend | None = None,
 ):
   del avals_out
   assert all(isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars)
@@ -1645,6 +1672,7 @@ def _rewritten_body(*args):
       compiler_params=compiler_params,
       cost_estimate=cost_estimate,
       out_avals=new_out_avals,
+      backend=backend,
   )
   refs_out, rest = split_list(out_flat, [num_refs])
   updated_vals_in = refs_out + [None] * len(rest_in_avals)
@@ -1666,6 +1694,7 @@ def pallas_call(
     name: str | None = None,
     compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None,
     cost_estimate: CostEstimate | None = None,
+    backend: _Backend | None = None,
 ) -> Callable[..., Any]:
   """Invokes a Pallas kernel on some inputs.

@@ -1715,6 +1744,8 @@ def pallas_call(
       platform is either 'mosaic' or 'triton'. It is also possible
       to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and
       `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs.
+    backend: Optional string, one of "mosaic_tpu", "triton" or "mosaic_gpu",
+      naming the backend to lower to. None means let Pallas decide.


   Returns:
@@ -1857,6 +1888,7 @@ def wrapped(*args):
         input_output_aliases=tuple(input_output_aliases.items()),
         compiler_params=compiler_params,
         cost_estimate=cost_estimate,
+        backend=backend,
     )
     out = tree_util.tree_unflatten(out_tree, out_flat)
     return out
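For reviewers, a minimal usage sketch of the new keyword. The kernel, the shapes, and the `backend="triton"` choice below are illustrative assumptions, not part of this commit; running it as written assumes a CUDA GPU with a Triton-enabled jaxlib:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Hypothetical kernel, used only to exercise the new keyword.
def add_one_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8, dtype=jnp.float32)

out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    # New in this commit: "mosaic_tpu", "triton", or "mosaic_gpu".
    # None (the default) keeps the old behavior: the platform decides,
    # with _PALLAS_USE_MOSAIC_GPU steering the GPU choice.
    backend="triton",
)(x)
```

Per the lowering rules above, an explicit `backend` overrides the `_PALLAS_USE_MOSAIC_GPU` flag on GPU, while on TPU any value other than `"mosaic_tpu"` (or None) raises a `ValueError`.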