
Commit 051687d

superbobry authored and Google-ML-Automation committed
[pallas] pallas_call_p is now parameterized by a mesh
The mesh is necessary to add support for clusters to the Mosaic GPU backend.

PiperOrigin-RevId: 737792129
1 parent b496613 · commit 051687d

10 files changed (+133, -38 lines)

jax/_src/pallas/core.py

Lines changed: 21 additions & 7 deletions

@@ -15,6 +15,7 @@
 """Module for pallas-core functionality."""
 from __future__ import annotations
 
+import collections
 from collections.abc import Callable, Iterable, Iterator, Sequence
 import contextlib
 import copy
@@ -1068,16 +1069,26 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **_):
   return [], effs
 
 
+class Mesh(Protocol):
+
+  @property
+  def backend(self) -> str:
+    ...
+
+  @property
+  def shape(self) -> collections.OrderedDict[object, int]:
+    ...
+
+
 _core_map_mesh_rules: dict[type[Any], Callable[..., Any]] = {}
 
 
 def default_mesh_discharge_rule(
     in_avals,
     out_avals,
     *args,
-    grid,
+    mesh,
     compiler_params,
-    backend,
     jaxpr,
     debug,
     interpret,
@@ -1100,19 +1111,22 @@ def body(*args):
       if isinstance(eff, state_types.WriteEffect)
   )
   any_spec = BlockSpec(memory_space=MemorySpace.ANY)
+  grid_spec = GridSpec(
+      grid=tuple(mesh.shape.items()),
+      in_specs=[any_spec] * len(in_avals),
+      out_specs=[any_spec] * len(modified_idxs),
+  )
   from jax._src.pallas import pallas_call  # Avoid circular dependency.
-  outs = pallas_call.pallas_call(
+  outs = pallas_call._pallas_call(
       body,
       name=name,
       out_shape=[in_avals[idx] for idx in modified_idxs],
-      in_specs=[any_spec] * len(in_avals),
-      out_specs=[any_spec] * len(modified_idxs),
       input_output_aliases={
           in_idx: out_idx for out_idx, in_idx in enumerate(modified_idxs)
       },
-      grid=grid,
+      grid_spec=grid_spec,
+      mesh=mesh,
       compiler_params=compiler_params,
-      backend=backend,
       interpret=interpret,
       debug=debug,
       cost_estimate=cost_estimate,
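
The new pallas_core.Mesh protocol asks a mesh only for a backend string and an ordered shape, and default_mesh_discharge_rule now derives the launch grid from that shape instead of taking grid and backend explicitly. A minimal sketch of a conforming mesh (the ToyMesh class and its axis sizes are illustrative, not part of this commit):

import collections

class ToyMesh:
  """Hypothetical mesh satisfying the pallas_core.Mesh protocol."""

  @property
  def backend(self) -> str:
    return "mosaic_tpu"  # selects which pallas_call backend handles the map

  @property
  def shape(self) -> collections.OrderedDict[object, int]:
    return collections.OrderedDict([("x", 2), ("y", 4)])

mesh = ToyMesh()
grid = tuple(mesh.shape.items())  # mirrors the GridSpec built in default_mesh_discharge_rule
assert grid == (("x", 2), ("y", 4))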

jax/_src/pallas/hlo_interpreter.py

Lines changed: 2 additions & 1 deletion

@@ -340,11 +340,12 @@ def pallas_call_hlo_interpret(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: Any,
     cost_estimate: CostEstimate,
     out_avals: tuple[jax_core.AbstractValue, ...],
 ):
-  del compiler_params, cost_estimate, out_avals
+  del mesh, compiler_params, cost_estimate, out_avals
   debug_info = jaxpr.debug_info
   # If we're in interpret mode, we *scan* over the grid and eval the
   # discharged jaxpr.

jax/_src/pallas/mosaic/core.py

Lines changed: 5 additions & 3 deletions

@@ -211,6 +211,10 @@ class TensorCoreMesh:
   devices: np.ndarray
   axis_names: Sequence[str]
 
+  @property
+  def backend(self) -> str:
+    return "mosaic_tpu"
+
   @property
   def shape(self):
     return collections.OrderedDict(zip(self.axis_names, self.devices.shape))
@@ -259,7 +263,6 @@ def _tensorcore_mesh_discharge_rule(
     compiler_params = TPUCompilerParams()
   if len(mesh.shape) > 1:
     raise NotImplementedError("Mesh must be 1D")
-  core_axis_name, num_cores = list(mesh.shape.items())[0]
   if compiler_params.dimension_semantics is not None:
     raise ValueError(
         "dimension_semantics must be None for TensorCoreMesh"
@@ -269,13 +272,12 @@
       out_avals,
       *args,
       jaxpr=jaxpr,
-      grid=((core_axis_name, num_cores),),
+      mesh=mesh,
      compiler_params=compiler_params.replace(
          dimension_semantics=(PARALLEL,)
      ),
      debug=debug,
      interpret=interpret,
-      backend="mosaic_tpu",
      cost_estimate=cost_estimate,
      name=name,
  )
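
Because TensorCoreMesh now reports its own backend, _tensorcore_mesh_discharge_rule no longer passes backend="mosaic_tpu" or an explicit grid; both are recovered from the mesh. A hedged sketch (assumes the pltpu.create_tensorcore_mesh helper, which needs only an axis name here, and a TPU runtime for the core count):

from jax.experimental.pallas import tpu as pltpu

mesh = pltpu.create_tensorcore_mesh("core")
assert mesh.backend == "mosaic_tpu"
grid = tuple(mesh.shape.items())  # e.g. (("core", num_cores),), what the default rule now builds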

jax/_src/pallas/mosaic/interpret.py

Lines changed: 2 additions & 1 deletion

@@ -1351,12 +1351,13 @@ def interpret_pallas_call(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: Any,
     cost_estimate: CostEstimate,
     out_avals: tuple[jax_core.AbstractValue, ...],
     interpret_params: TPUInterpretParams,
 ):
-  del debug, cost_estimate, out_avals
+  del debug, mesh, cost_estimate, out_avals
 
   # args contains: *dynamic_grid_sizes, *index, *inputs. (No consts?)
   dynamic_grid_args, scalars, input_args = split_list(

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 9 additions & 7 deletions

@@ -108,6 +108,7 @@ def pallas_call_tpu_lowering_rule(
     *in_nodes,
     jaxpr: jax_core.Jaxpr,
     grid_mapping: core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     input_output_aliases: tuple[tuple[int, int], ...],
     debug: bool,
     interpret: bool,
@@ -116,7 +117,8 @@
     out_avals: tuple[jax_core.AbstractValue, ...],
 ):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
-  del interpret
+  del mesh, interpret  # Unused.
+
   debug_info = jaxpr._debug_info
   if debug:
     print(f"\nThe kernel jaxpr for pallas_call {debug_info.func_src_info}:")
@@ -126,11 +128,11 @@
   else:
     mosaic_params = {}
 
-  mesh = None
+  jax_mesh = None
   axis_context = ctx.module_context.axis_context
   if axis_context is not None:
     if isinstance(axis_context, sharding_impls.SPMDAxisContext):
-      mesh = axis_context.mesh
+      jax_mesh = axis_context.mesh
   mlir_ctx = mlir.JaxIrContext()
   mlir_ctx.append_dialect_registry(mlir.upstream_dialects)
   mlir_ctx.load_all_available_dialects()
@@ -147,7 +149,7 @@ def lower_module(for_verification: bool):
         grid_mapping,
         jaxpr,
         dimension_semantics=dimension_semantics,
-        mesh=mesh,
+        mesh=jax_mesh,
         for_verification=for_verification,
         dynamic_shape_replacement_enabled=pallas_core.dynamic_shapes_export_enabled(),
     )
@@ -164,11 +166,11 @@
   )
 
   if promela_dump_path := _DUMP_PROMELA_TO.value:
-    num_devices = 1 if mesh is None else mesh.devices.size
+    num_devices = 1 if jax_mesh is None else jax_mesh.devices.size
     num_cores = (
         jax.devices()[0].num_cores
-        if mesh is None
-        else mesh.devices[0].num_cores
+        if jax_mesh is None
+        else jax_mesh.devices[0].num_cores
     )
     verification_module, _ = lower_module(for_verification=True)
     model = verification.export_promela_model(

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 11 additions & 5 deletions

@@ -18,7 +18,7 @@
 
 import abc
 import collections
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 import dataclasses
 import enum
 import itertools as it
@@ -519,9 +519,16 @@ def __post_init__(self):
     )
 
   @property
-  def shape(self):
+  def backend(self) -> str:
+    return "mosaic_gpu"
+
+  @property
+  def shape(self) -> collections.OrderedDict[object, int]:
+    pairs: Iterable[tuple[object, int]]
     if self.num_threads is not None:
-      pairs = zip(self.axis_names, (*self.grid, *self.cluster, self.num_threads))
+      pairs = zip(
+          self.axis_names, (*self.grid, *self.cluster, self.num_threads)
+      )
     else:
       pairs = tuple(
           zip(
@@ -563,8 +570,7 @@ def _gpu_mesh_discharge_rule(
       out_avals,
       *args,
       jaxpr=jaxpr,
-      grid=tuple(mesh.shape.items()),
-      backend="mosaic_gpu",
+      mesh=mesh,
      compiler_params=compiler_params,
      debug=debug,
      interpret=interpret,
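
GPUMesh satisfies the same protocol: backend is "mosaic_gpu" and shape lists the grid axes, then the cluster axes, then the thread axis. A hedged sketch (assumes GPUMesh accepts grid/cluster/num_threads/axis_names keyword arguments, as the fields read in shape suggest):

import collections
from jax.experimental.pallas import mosaic_gpu as plgpu

mesh = plgpu.GPUMesh(
    grid=(4,),
    cluster=(2,),
    num_threads=2,
    axis_names=("x", "cluster", "wg"),
)
assert mesh.backend == "mosaic_gpu"
# Axis order: grid, then cluster, then the warpgroup/thread axis.
assert mesh.shape == collections.OrderedDict([("x", 4), ("cluster", 2), ("wg", 2)])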

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 8 additions & 2 deletions

@@ -450,6 +450,7 @@ def index_map(*indices):
 
 def lower_pipelined_jaxpr_to_module(
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     jaxpr: jax_core.Jaxpr,
     compiler_params: dict[str, Any],
     cost_estimate: pallas_core.CostEstimate | None,
@@ -473,7 +474,10 @@
       block_mappings, [grid_mapping.num_inputs]
   )
 
-  if grid_mapping.grid_names:  # Last dim corresponds to the warpgroup count
+  if mesh is not None:
+    assert isinstance(mesh, gpu_core.GPUMesh)
+  if mesh and mesh.num_threads is not None:
+    # Last dim corresponds to the warpgroup count.
     block = (128 * grid_mapping.grid[-1], 1, 1)
     grid = grid_mapping.grid[:-1]
   else:
@@ -566,6 +570,7 @@ def body_fn(*refs):
       parallel_grid,
       grid_mapping.grid_names,
       block,
+      mesh.cluster if mesh is not None else (),
      [bm.array_shape_dtype for bm in in_block_mappings],
      [bm.array_shape_dtype for bm in out_block_mappings],
      new_jaxpr,
@@ -578,6 +583,7 @@ def lower_jaxpr_to_module(
     grid: Sequence[int],
     grid_names: Sequence[str],
     block: Sequence[int],
+    cluster: Sequence[int],
     in_shapes: Sequence[jax.ShapeDtypeStruct],
     out_shapes: Sequence[jax.ShapeDtypeStruct],
     jaxpr: jax_core.Jaxpr,
@@ -640,7 +646,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
   mgpu_core._lower_as_gpu_kernel(
       body,
       grid=parallel_grid,
-      cluster=(),
+      cluster=cluster,
      block=block,
      in_shapes=in_shapes,
      out_shape=out_shapes,
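
When the mesh declares num_threads, the trailing grid dimension is the warpgroup count and each warpgroup contributes 128 threads to the CUDA block; the mesh's cluster is now forwarded to _lower_as_gpu_kernel instead of the hard-coded (). A worked example with illustrative numbers:

grid = (16, 2)                   # grid_mapping.grid: 16 programs x 2 warpgroups
block = (128 * grid[-1], 1, 1)   # -> (256, 1, 1) threads per block
parallel_grid = grid[:-1]        # -> (16,)
cluster = (2,)                   # mesh.cluster, passed through to the kernel launch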

jax/_src/pallas/mosaic_gpu/pallas_call_registration.py

Lines changed: 2 additions & 0 deletions

@@ -38,6 +38,7 @@ def pallas_call_lowering(
     debug: bool,
     input_output_aliases: tuple[tuple[int, int], ...],
     grid_mapping: pallas_core.GridMapping,
+    mesh: pallas_core.Mesh | None,
     compiler_params: dict[str, Any],
     cost_estimate: pallas_core.CostEstimate | None,
     out_avals: tuple[jax_core.AbstractValue, ...],
@@ -63,6 +64,7 @@ def pallas_call_lowering(
 
   lowering_result = lowering.lower_pipelined_jaxpr_to_module(
       grid_mapping,
+      mesh,
      jaxpr,
      compiler_params,
      cost_estimate,
