Commit aefe621

superbobry authored and Google-ML-Automation committed
[pallas:mosaic_gpu] Ported two pipelining optimizations to emit_pipeline
* Skip SMEM->GMEM copy if the destination buffer is being revisited
* Skip SMEM->GMEM copy if the corresponding index map does not use grid indices

PiperOrigin-RevId: 696448043
1 parent 8370082 commit aefe621
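
For readers skimming the diff, the two cases being optimized look like this in `BlockSpec` terms (hypothetical specs for illustration; not taken from this commit):

```python
from jax.experimental import pallas as pl

# Revisited destination: the index map ignores the innermost grid axis ``j``,
# so consecutive steps along ``j`` store to the same GMEM block. All such
# stores except the last can now be skipped.
revisited_spec = pl.BlockSpec((64, 64), lambda i, j: (i, 0))

# Index-invariant destination: the index map uses no grid indices at all,
# so a single store after the pipeline loop suffices.
invariant_spec = pl.BlockSpec((64, 64), lambda i, j: (0, 0))
```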

File tree

5 files changed: +195 -28 lines changed


jax/_src/pallas/mosaic_gpu/BUILD

Lines changed: 3 additions & 0 deletions

@@ -107,7 +107,10 @@ pytype_strict_library(
         ":core",
         ":primitives",
         "//jax",
+        "//jax:core",
+        "//jax:mosaic_gpu",
         "//jax:pallas",
+        "//jax:partial_eval",
         "//jax:util",
         "//jax/_src/pallas",
     ],

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 9 additions & 1 deletion

@@ -264,6 +264,7 @@ def scratch_view(
 class LoweringRuleContext:
   module_ctx: ModuleContext
   launch_ctx: mgpu.LaunchContext
+  predicate: ir.Value
   avals_in: Sequence[jax_core.ShapedArray]
   avals_out: Sequence[jax_core.ShapedArray]

@@ -878,6 +879,7 @@ def write_env(var: jax_core.Var, val):
     rule_ctx = LoweringRuleContext(
         module_ctx,
         launch_ctx,
+        predicate=mgpu.single_thread_predicate(per_block=False),
         avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
         avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
     )

@@ -1120,6 +1122,12 @@ def _convert_element_type_lowering_rule(
 )


+mosaic_lowering_rules.update({
+    lax.neg_p: lambda ctx, x: -x,
+    lax.not_p: lambda ctx, x: ~x,
+})
+
+
 def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl):
   x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
   return impl(x, y)

@@ -1576,4 +1584,4 @@ def _as_index(v: object) -> ir.Value:
     case mgpu.FragmentedArray(layout=mgpu.WGSplatFragLayout()):
       return _as_index(v.registers.item())
     case _:
-      raise ValueError(f"Unsupported index: {v}")
+      raise ValueError(f"Unsupported index: {v} of type {type(v)}")

jax/_src/pallas/mosaic_gpu/pipeline.py

Lines changed: 124 additions & 23 deletions

@@ -16,7 +16,7 @@

 from __future__ import annotations

-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 import dataclasses
 import functools
 import itertools as it

@@ -25,7 +25,10 @@

 import jax
 from jax import lax
+from jax._src import core
+from jax._src import linear_util as lu
 from jax._src import util
+from jax._src.interpreters import partial_eval as pe
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.mosaic_gpu import core as gpu_core
 from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives

@@ -37,17 +40,19 @@
 zip = util.safe_zip


+@jax.tree_util.register_dataclass
 @dataclasses.dataclass(frozen=True)
 class BufferedRef:
-  spec: pallas_core.BlockSpec
+  spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True})
+  is_index_invariant: bool = dataclasses.field(metadata={"static": True})
   gmem_ref: pallas_core.AbstractMemoryRef
   smem_ref: pallas_core.AbstractMemoryRef  # [num_slots, *spec.block_shape]

-  def compute_gmem_slice(self, grid_indices) -> tuple[Any, ...]:
+  def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]:
     index_map = self.spec.index_map
     assert index_map is not None
     return tuple(
-        pl.ds(idx * size, size)
+        pl.Slice(idx * size, size)  # type: ignore[arg-type]
         for idx, size in zip(
             index_map(*grid_indices), self.spec.block_shape  # type: ignore[arg-type]
         )

@@ -61,16 +66,31 @@ def copy_in(self, slot, grid_indices, barrier_ref):
         barrier=barrier_ref.at[slot],
     )

-  def copy_out(self, slot, grid_indices):
+  def copy_out(self, slot, grid_indices, predicate=None):
     gmem_slices = self.compute_gmem_slice(grid_indices)
     gpu_primitives.copy_smem_to_gmem(
-        self.smem_ref.at[slot], self.gmem_ref.at[gmem_slices]  # pytype: disable=unsupported-operands
+        self.smem_ref.at[slot],
+        self.gmem_ref.at[gmem_slices],  # pytype: disable=unsupported-operands
+        predicate=predicate,
     )


-jax.tree_util.register_dataclass(
-    BufferedRef, data_fields=["gmem_ref", "smem_ref"], meta_fields=["spec"]
-)
+def _uses_arguments(
+    index_map: Callable[..., Any], num_args: int
+) -> Sequence[bool]:
+  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
+      lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args
+  )
+  _, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))
+  return used_inputs
+
+
+def _is_index_invariant(
+    spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid
+) -> bool:
+  index_map = spec.index_map
+  assert index_map is not None
+  return not any(_uses_arguments(index_map, len(grid)))


 def _inc_grid_by_1(

@@ -85,6 +105,25 @@ def _inc_grid_by_1(
   return tuple(reversed(next_indices))


+# ``pl.Slice`` uses a different pytree encoding, depending on whether the
+# start/size are static or dynamic. This leads to pytree structure mismatch
+# in the pipeline body. So, we define a different ``Slice`` class below.
+
+
+@dataclasses.dataclass(frozen=True)
+class _Slice:
+  start: int | jax.Array
+  size: int | jax.Array
+
+  def __eq__(self, other: _Slice) -> jax.Array:  # type: ignore
+    return lax.bitwise_and(self.start == other.start, self.size == other.size)
+
+
+jax.tree_util.register_dataclass(
+    _Slice, data_fields=["start", "size"], meta_fields=[]
+)
+
+
 def emit_pipeline(
     body,
     *,

@@ -102,6 +141,16 @@
     max_concurrent_steps = num_steps

   def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
+    for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
+      if any(
+          spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx]  # type: ignore
+          for idx in range(1, len(grid) + 1)
+      ):
+        raise NotImplementedError(
+            f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block"
+            f" shape {spec.block_shape}."
+        )
+
     in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)])
     in_smem_refs, out_smem_refs = util.split_list(
         map(

@@ -132,13 +181,18 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
     def scoped_pipeline(
         *, in_gmem_refs, out_gmem_refs, in_smem_refs, out_smem_refs, barrier_ref
     ):
-
-      in_brefs: Sequence[BufferedRef] = map(
-          BufferedRef, in_specs, in_gmem_refs, in_smem_refs
-      )
-      out_brefs: Sequence[BufferedRef] = map(
-          BufferedRef, out_specs, out_gmem_refs, out_smem_refs
-      )
+      in_brefs: Sequence[BufferedRef] = [
+          BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref)
+          for spec, gmem_ref, smem_ref in zip(
+              in_specs, in_gmem_refs, in_smem_refs
+          )
+      ]
+      out_brefs: Sequence[BufferedRef] = [
+          BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref)
+          for spec, gmem_ref, smem_ref in zip(
+              out_specs, out_gmem_refs, out_smem_refs
+          )
+      ]

       for step, indices in enumerate(
           it.islice(it.product(*map(range, grid)), max_concurrent_steps)

@@ -147,10 +201,11 @@ def scoped_pipeline(

       def loop_body(step, carry):
         slot = step % max_concurrent_steps
-        indices, fetch_indices = carry
+        indices, fetch_indices, last_store_slices = carry

-        # Wait for the current GMEM->SMEM copy to complete.
-        gpu_primitives.barrier_wait(barrier_ref.at[slot])
+        if in_specs:
+          # Wait for the current GMEM->SMEM copy to complete.
+          gpu_primitives.barrier_wait(barrier_ref.at[slot])
         # Wait for the previous output SMEM->GMEM copy to complete.
         gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1)

@@ -159,9 +214,34 @@ def loop_body(step, carry):
             *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs))
         )

+        if not all(bref.is_index_invariant for bref in out_brefs):
+          gpu_primitives.commit_smem()
+
         # Copy the output from SMEM to GMEM.
-        gpu_primitives.commit_smem()
-        map(lambda bref: bref.copy_out(slot, indices), out_brefs)
+        new_store_slices = last_store_slices[:]
+        for idx, bref in enumerate(out_brefs):
+          if bref.is_index_invariant:
+            assert last_store_slices[idx] is None
+            continue
+          assert last_store_slices[idx] is not None
+          new_store_slices[idx] = tuple(
+              _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices)
+          )
+          are_same_slices = map(
+              lambda old, new: old == new,
+              last_store_slices[idx],
+              new_store_slices[idx],
+          )
+          slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices)
+          is_last_step = step == num_steps - 1
+          # TODO(apaszke,slebedev): This still diverges significantly from the
+          # TPU semantics in that it will move on to the next SMEM output slice
+          # even if it's not storing the previous one.
+          bref.copy_out(
+              slot,
+              indices,
+              predicate=lax.bitwise_or(slices_changed, is_last_step),
+          )

         fetch_step = step + max_concurrent_steps
         fetch_slot = slot  # (x + y) % y == x % y

@@ -174,13 +254,34 @@ def loop_body(step, carry):
             lambda: [None] * len(in_brefs),
         )

-        return _inc_grid_by_1(indices, grid), _inc_grid_by_1(fetch_indices, grid)
+        return (
+            _inc_grid_by_1(indices, grid),
+            _inc_grid_by_1(fetch_indices, grid),
+            new_store_slices,
+        )

       indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid)
       fetch_indices = indices
       for _ in range(max_concurrent_steps):
         fetch_indices = _inc_grid_by_1(fetch_indices, grid)
-      lax.fori_loop(0, num_steps, loop_body, (indices, fetch_indices))
+      last_store_slices = [
+          None
+          if bref.is_index_invariant
+          else (_Slice(-1, -1),) * len(bref.spec.block_shape)
+          for bref in out_brefs
+      ]
+      last_indices, _, _ = lax.fori_loop(
+          0, num_steps, loop_body, (indices, fetch_indices, last_store_slices)
+      )
+
+      # Outputs invariant to the sequential axis are never written from inside the
+      # loop. This is the only place where we store them.
+      if all(bref.is_index_invariant for bref in out_brefs):
+        gpu_primitives.commit_smem()
+      last_slot = (num_steps - 1) % max_concurrent_steps
+      for bref in out_brefs:
+        if bref.is_index_invariant:
+          bref.copy_out(last_slot, last_indices, predicate=None)

       # Finalize the pipeline.
       gpu_primitives.wait_smem_to_gmem(0)
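
The `_is_index_invariant` check above decides at trace time whether an index map reads any grid index: the map is traced to a jaxpr on scalar int32 arguments, and dead-code elimination reports which inputs the outputs actually depend on. A standalone sketch, using the same internal JAX APIs as the diff (signatures as of this commit; not a public interface):

```python
from collections.abc import Callable, Sequence
from typing import Any

import jax.numpy as jnp
from jax._src import core
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe


def uses_arguments(index_map: Callable[..., Any], num_args: int) -> Sequence[bool]:
  # Trace the index map on scalar int32 grid indices...
  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args
  )
  # ...and let DCE report which inputs are actually used by the outputs.
  _, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))
  return used_inputs


print(uses_arguments(lambda i, j: (i, 0), 2))  # [True, False]
print(uses_arguments(lambda i, j: (0, 0), 2))  # [False, False] -> index-invariant
```

The custom `_Slice` carried through `lax.fori_loop` exists for a related tracing reason: the loop carry must keep a fixed pytree structure across iterations, and its `__eq__` returns a traced `jax.Array` so the store-skip predicate can be computed inside the loop.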

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 28 additions & 4 deletions

@@ -26,6 +26,7 @@
 from jax._src import tree_util
 from jax._src import util
 from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import arith as arith_dialect
 from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.mosaic_gpu import core as gpu_core

@@ -34,6 +35,7 @@
 from jax._src.state import indexing
 from jax._src.state import primitives as state_primitives
 import jax.experimental.mosaic.gpu as mgpu
+import jax.numpy as jnp


 WARPGROUP_SIZE = 128

@@ -54,19 +56,31 @@ def _copy_smem_to_gmem_lowering(
     ctx: lowering.LoweringRuleContext,
     src,
     dst,
-    *flat_transforms,
+    *flat_args,
     src_transforms_treedef,
     dst_transforms_treedef,
+    has_user_predicate,
 ):
+  predicate = ctx.predicate
+  if has_user_predicate:
+    flat_args, user_predicate = flat_args[:-1], flat_args[-1]
+    predicate = arith_dialect.andi(
+        predicate, lowering._ensure_ir_value(user_predicate, jnp.bool)
+    )
   flat_src_transforms, flat_dst_transforms = util.split_list(
-      flat_transforms,
+      flat_args,
       [src_transforms_treedef.num_leaves],
   )
   src_transforms = src_transforms_treedef.unflatten(flat_src_transforms)
   dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms)
   src, src_transforms = lowering._handle_indexing(src, src_transforms)
   copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms)
-  ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params)
+  ctx.launch_ctx.async_copy(
+      src_ref=src,
+      dst_ref=dst,
+      predicate=predicate,
+      **copy_params,
+  )
   return ()


@@ -98,10 +112,18 @@ def _extract_smem_copy_params(transforms):


 def copy_smem_to_gmem(
-    src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef
+    src: pallas_core.AbstractMemoryRef,
+    dst: pallas_core.AbstractMemoryRef,
+    predicate: jax.Array | None = None,
 ) -> None:
   """Asynchronously copies a SMEM reference to a GMEM reference.

+  Args:
+    src: The SMEM reference to copy from.
+    dst: The GMEM reference to copy to.
+    predicate: A boolean indicating whether the copy should be performed. If
+      ``None``, the copy is always performed.
+
   See also:
     :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
     :func:`jax.experimental.mosaic.gpu.commit_smem`

@@ -127,8 +149,10 @@
       dst,
       *flat_src_transforms,
       *flat_dst_transforms,
+      *[] if predicate is None else [predicate],
       src_transforms_treedef=src_transforms_treedef,
       dst_transforms_treedef=dst_transforms_treedef,
+      has_user_predicate=predicate is not None,
   )
   return None
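A hedged usage sketch of the extended `copy_smem_to_gmem` (the kernel context and ref arguments are hypothetical; only the `predicate` keyword is new in this commit):

```python
import jax
from jax import lax
from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives


def maybe_store(smem_ref, gmem_ref, slices_changed: jax.Array, is_last_step: jax.Array):
  # The copy is issued only when the predicate is true. The lowering ANDs
  # this user predicate with the warpgroup's single-thread predicate.
  gpu_primitives.copy_smem_to_gmem(
      smem_ref, gmem_ref, predicate=lax.bitwise_or(slices_changed, is_last_step)
  )
```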