Skip to content

Commit 15f30a9

Browse files
superbobry authored and Google-ML-Automation committed
[pallas:mosaic_gpu] emit_pipeline now maintains the grid indices
Previously, it was recomputing them at every loop iteration. PiperOrigin-RevId: 695682116
1 parent cb82609 commit 15f30a9

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from jax._src.pallas.mosaic_gpu import core as gpu_core
3131
from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives
3232
from jax.experimental import pallas as pl
33+
import jax.numpy as jnp
3334

3435

3536
map = util.safe_map
@@ -72,15 +73,16 @@ def copy_out(self, slot, grid_indices):
7273
)
7374

7475

75-
def make_grid_indices(
76-
step: jax.typing.ArrayLike, grid: Sequence[int]
76+
def _inc_grid_by_1(
77+
indices: tuple[jax.Array, ...], grid: Sequence[int]
7778
) -> tuple[jax.Array, ...]:
78-
# TODO(slebedev): Maintain the grid index through the fori_loop instead.
79-
indices = []
80-
for size in reversed(grid):
81-
indices.append(lax.rem(step, size))
82-
step = lax.div(step, size)
83-
return tuple(reversed(indices))
79+
next_indices = []
80+
carry: bool | jax.Array = True
81+
for idx, size in reversed(list(zip(indices, grid))):
82+
next_idx = lax.select(carry, idx + 1, idx)
83+
carry = next_idx == size
84+
next_indices.append(lax.select(carry, 0, next_idx).astype(idx.dtype))
85+
return tuple(reversed(next_indices))
8486

8587

8688
def emit_pipeline(
@@ -143,15 +145,15 @@ def scoped_pipeline(
143145
):
144146
map(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs)
145147

146-
def loop_body(step, _):
148+
def loop_body(step, carry):
147149
slot = step % max_concurrent_steps
150+
indices, fetch_indices = carry
148151

149152
# Wait for the current GMEM->SMEM copy to complete.
150153
gpu_primitives.barrier_wait(barrier_ref.at[slot])
151154
# Wait for the previous output SMEM->GMEM copy to complete.
152155
gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1)
153156

154-
indices = make_grid_indices(step, grid)
155157
with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
156158
body(
157159
*(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs))
@@ -166,17 +168,19 @@ def loop_body(step, _):
166168
jax.lax.cond(
167169
fetch_step < num_steps,
168170
lambda: map(
169-
lambda bref: bref.copy_in(
170-
fetch_slot, make_grid_indices(fetch_step, grid), barrier_ref
171-
),
171+
lambda bref: bref.copy_in(fetch_slot, fetch_indices, barrier_ref),
172172
in_brefs,
173173
),
174174
lambda: [None] * len(in_brefs),
175175
)
176176

177-
return ()
177+
return _inc_grid_by_1(indices, grid), _inc_grid_by_1(fetch_indices, grid)
178178

179-
lax.fori_loop(0, num_steps, loop_body, ())
179+
indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid)
180+
fetch_indices = indices
181+
for _ in range(max_concurrent_steps):
182+
fetch_indices = _inc_grid_by_1(fetch_indices, grid)
183+
lax.fori_loop(0, num_steps, loop_body, (indices, fetch_indices))
180184

181185
# Finalize the pipeline.
182186
gpu_primitives.wait_smem_to_gmem(0)

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1184,7 +1184,11 @@ def select(self, on_true, on_false):
11841184
or ir.IntegerType(self.mlir_dtype).width != 1
11851185
):
11861186
raise NotImplementedError
1187-
return self._pointwise(arith.select, on_true, on_false)
1187+
# We change the receiver here, because the return type is defined by
1188+
# `on_true` and `on_false` and not the predicate `self`.
1189+
return on_true._pointwise(
1190+
lambda t, p, f: arith.select(p, t, f), self, on_false,
1191+
)
11881192

11891193
def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]):
11901194
"""Call a function for each value and index."""

0 commit comments

Comments (0)