Commit d0b71fa

justinjfu authored and Google-ML-Automation committed
[Mosaic GPU] Add preliminary TMEM allocation support for Pallas/Mosaic GPU.
PiperOrigin-RevId: 738932990
1 parent 80784a5 commit d0b71fa

File tree

5 files changed: +134 −10 lines

5 files changed

+134
-10
lines changed
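Before the per-file diffs, here is a minimal sketch of the user-facing pattern this commit enables, adapted from the new test_tmem_alloc test at the bottom of the diff. It assumes an sm_100a (Blackwell) GPU; the import aliases are the usual Pallas ones.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

mesh = plgpu.GPUMesh(num_threads=1, axis_names=("x"))

@pl.run_state
def kernel(y_ref):
  @pl.core_map(mesh)
  def _():
    def scope(tmem_ref, smem_ref):
      # Read the scoped TMEM allocation and copy it out via SMEM.
      smem_ref[...] = tmem_ref[...]
      plgpu.commit_smem()
      plgpu.copy_smem_to_gmem(smem_ref, y_ref)
      plgpu.wait_smem_to_gmem(0)
    # TMEM requests must be 2D and 32-bit, with shape[0] % 128 == 0 and
    # shape[1] % 8 == 0 (enforced by the resource estimator below).
    pl.run_scoped(scope,
                  plgpu.TMEM((128, 128), jnp.float32),
                  plgpu.SMEM((128, 128), jnp.float32))

out = jax.block_until_ready(kernel(jnp.zeros((128, 128), np.float32)))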

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 3 additions & 0 deletions
@@ -101,6 +101,8 @@ class GPUMemorySpace(enum.Enum):
   GMEM = "gmem"
   #: Shared memory.
   SMEM = "smem"
+  #: Tensor memory.
+  TMEM = "tmem"
   #: Registers.
   REGS = "regs"
 
@@ -452,6 +454,7 @@ def to_block_mapping(
 
 GMEM = GPUMemorySpace.GMEM
 SMEM = GPUMemorySpace.SMEM
+TMEM = GPUMemorySpace.TMEM
 REGS = GPUMemorySpace.REGS
 
jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 98 additions & 10 deletions
@@ -59,6 +59,7 @@
 from jax.experimental.mosaic.gpu import core as mgpu_core
 from jax.experimental.mosaic.gpu import profiler as mgpu_profiler
 from jax.experimental.mosaic.gpu import utils as mgpu_utils
+from jax.experimental.mosaic.gpu import tcgen05
 import jax.numpy as jnp
 import numpy as np
 
@@ -100,6 +101,7 @@ def arrival_multiplier(self) -> int:
 @dataclasses.dataclass(kw_only=True, frozen=True)
 class Resources:
   smem_scratch_bytes: int = 0
+  tmem_scratch_cols: int = 0
   barrier_counts: collections.Counter[mgpu.Barrier] = dataclasses.field(
       default_factory=collections.Counter
   )
@@ -110,6 +112,12 @@ def __post_init__(self):
         "smem_scratch_bytes",
         _align_to(self.smem_scratch_bytes, _SMEM_ALIGNMENT),
     )
+    object.__setattr__(
+        self,
+        "tmem_scratch_cols",
+        # TMEM must be allocated in 128x8 chunks.
+        _align_to(self.tmem_scratch_cols, 8),
+    )
 
   @property
   def barriers(self) -> Sequence[mgpu.Barrier]:
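To make the chunking concrete: the hunk above rounds the requested column count up to a multiple of 8. `_align_to` itself is not shown in this diff; the sketch below is an assumed implementation that matches how it is used here.

def _align_to(x: int, alignment: int) -> int:
  # Round x up to the nearest multiple of `alignment`.
  return ((x + alignment - 1) // alignment) * alignment

assert _align_to(100, 8) == 104  # 100 requested TMEM columns become 104
assert _align_to(128, 8) == 128  # already a multiple of 8; unchanged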
@@ -122,6 +130,7 @@ def __add__(self, other: Resources) -> Resources:
     # we will allocate two barriers, even though one would be enough.
     return Resources(
         smem_scratch_bytes=self.smem_scratch_bytes + other.smem_scratch_bytes,
+        tmem_scratch_cols=self.tmem_scratch_cols + other.tmem_scratch_cols,
         barrier_counts=self.barrier_counts + other.barrier_counts,
     )
 
@@ -130,6 +139,9 @@ def __or__(self, other: Resources) -> Resources:
         smem_scratch_bytes=max(
             self.smem_scratch_bytes, other.smem_scratch_bytes
         ),
+        tmem_scratch_cols=max(
+            self.tmem_scratch_cols, other.tmem_scratch_cols
+        ),
         barrier_counts=self.barrier_counts | other.barrier_counts,
     )
 
@@ -218,10 +230,26 @@ def _run_scoped_resource_estimator(
               )
           ])
       )
-    else:
+    elif aval.memory_space == gpu_core.TMEM:
+      if aval.dtype.itemsize != 4:
+        raise ValueError("TMEM only supports 32-bit types.")
+      if len(aval.shape) != 2:
+        raise ValueError("TMEM allocations must be 2D.")
+      if aval.shape[0] % tcgen05.TMEM_ROWS != 0:
+        raise ValueError("TMEM shape[0] must be a multiple of 128.")
+      if aval.shape[1] % 8 != 0:
+        raise ValueError("TMEM shape[1] must be a multiple of 8.")
+      rs += Resources(tmem_scratch_cols=aval.shape[1])
+    elif aval.memory_space == gpu_core.SMEM:
       rs += Resources(
           smem_scratch_bytes=math.prod(aval.shape) * aval.dtype.itemsize
       )
+    elif aval.memory_space == gpu_core.REGS:
+      # Don't need to allocate anything.
+      pass
+    else:
+      raise NotImplementedError(
+          f"Unsupported memory space: {aval.memory_space}")
   return rs + _estimate_resources(ctx, jaxpr)
 
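As a quick reference, these are allocations the new estimator branch accepts or rejects; a sketch based on the checks above, using the plgpu.TMEM alias that this commit adds in jax/experimental/pallas/mosaic_gpu.py.

# Accepted (2D, 32-bit, rows % 128 == 0, cols % 8 == 0):
plgpu.TMEM((128, 8), jnp.float32)
plgpu.TMEM((256, 64), jnp.int32)
# Rejected with a ValueError when the allocation is lowered:
plgpu.TMEM((64, 8), jnp.float32)     # shape[0] not a multiple of 128
plgpu.TMEM((128, 12), jnp.float32)   # shape[1] not a multiple of 8
plgpu.TMEM((128, 8), jnp.float16)    # not a 32-bit type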
@@ -267,6 +295,9 @@ class ModuleContext:
   single_wg_lane_predicate: ir.Value
   smem_requested_bytes: int
   smem_used_bytes: int
+  tmem_requested_cols: int
+  tmem_used_cols: int
+  tmem_base_ptr: ir.Value
   runtime_barriers: MutableMapping[
       mgpu.Barrier, MutableSequence[mgpu.BarrierRef]
   ]
@@ -286,6 +317,27 @@ def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef:
       raise RuntimeError(f"Barrier {barrier} is already reserved")
     return available.pop()
 
+  @contextlib.contextmanager
+  def alloc_tmem(
+      self,
+      struct: jax.ShapeDtypeStruct,
+      layout: tcgen05.TMEMLayout | None = None
+  ) -> ir.Value:
+    if self.tmem_used_cols > 0:
+      raise NotImplementedError(
+          "Multiple TMEM allocations are not implemented.")
+    if layout is None:
+      layout = tcgen05._infer_tmem_layout(struct.shape, collective=False)
+    cols_used = np.prod(struct.shape) // tcgen05.TMEM_ROWS
+    self.tmem_used_cols += cols_used
+    off = self.tmem_base_ptr
+    tmem_ref = tcgen05.TMEMRef(address=off,
+                               shape=struct.shape,
+                               dtype=mgpu_utils.dtype_to_ir_type(struct.dtype),
+                               layout=layout)
+    yield tmem_ref
+    self.tmem_used_cols -= cols_used
+
   # TODO(cperivol): Only return the shapes and figure out the sizes when freeing.
   @contextlib.contextmanager
   def scratch_view(
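The column accounting in alloc_tmem treats tensor memory as a fixed 128-row array, so a buffer's column footprint is its element count divided by tcgen05.TMEM_ROWS. Worked for the shape used in the new test:

TMEM_ROWS = 128                   # tcgen05.TMEM_ROWS: fixed row count of TMEM
shape = (128, 128)                # the (rows, cols) buffer from the test below
cols_used = (shape[0] * shape[1]) // TMEM_ROWS
assert cols_used == 128           # a 128x128 f32 buffer occupies 128 columns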
@@ -642,11 +694,15 @@ def lower_jaxpr_to_module(
   parallel_grid = (math.prod(grid[:-2]), *grid[-2:])
 
   def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
-    *buffers_gmem, (runtime_smem, runtime_barriers) = buffers
+    *buffers_gmem, (runtime_smem, runtime_barriers, runtime_tmem) = buffers
 
     grouped_barriers = collections.defaultdict(list)
     for barrier, barrier_ref in zip(rs.barriers, runtime_barriers):
       grouped_barriers[barrier].append(barrier_ref)
+    if runtime_tmem is not None:
+      tmem_cols = math.prod(runtime_tmem.shape) // tcgen05.TMEM_ROWS
+    else:
+      tmem_cols = 0
     module_ctx = ModuleContext(
         mlir.sanitize_name(debug_info.func_name),
         axis_names,
@@ -655,6 +711,9 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
         mgpu.single_thread_predicate(per_block=False),
         smem_requested_bytes=math.prod(ir.MemRefType(runtime_smem.type).shape),
         smem_used_bytes=0,
+        tmem_requested_cols=tmem_cols,
+        tmem_used_cols=0,
+        tmem_base_ptr=runtime_tmem.address if runtime_tmem else None,
         runtime_barriers=grouped_barriers,
         name_stack=source_info_util.NameStack(),
         traceback_caches=mlir.TracebackCaches(),
@@ -671,6 +730,18 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
   smem_scratch_bytes = params.get("smem_scratch_bytes")
   if smem_scratch_bytes is None:
     smem_scratch_bytes = rs.smem_scratch_bytes
+  tmem_scratch_cols = rs.tmem_scratch_cols
+
+  scratch_buffers = [
+      jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8),
+      rs.barriers,
+  ]
+  if tmem_scratch_cols > 0:
+    scratch_buffers.append(
+        mgpu.TMEM(shape=[tcgen05.TMEM_ROWS, tmem_scratch_cols], dtype=np.int32),
+    )
+  else:
+    scratch_buffers.append(None)
 
   prof_ctx = prof_spec = None
   if prof_space := params.get("profile_space", 0):
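Concretely, for a kernel that needs, say, 256 bytes of SMEM scratch and 16 TMEM columns, the scratch tree built above would look like the following sketch (names mirror the code; the numbers are illustrative):

scratch_buffers = [
    jax.ShapeDtypeStruct(shape=[256], dtype=np.int8),  # byte-addressed SMEM scratch
    rs.barriers,                                       # one entry per requested barrier
    mgpu.TMEM(shape=[128, 16], dtype=np.int32),        # 128 rows x 16 cols of TMEM
]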
@@ -685,10 +756,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
       block=block,
       in_shapes=in_shapes,
       out_shape=out_shapes,
-      smem_scratch_shape=(
-          jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8),
-          rs.barriers,
-      ),
+      smem_scratch_shape=scratch_buffers,
       module_name=mlir.sanitize_name(debug_info.func_name),
       prof_spec=prof_spec,
   )
@@ -990,14 +1058,26 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...
 
 
 @register_lowering_rule(sp.get_p, mgpu.ThreadSemantics.Lane)
-def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *leaves, tree):
-  if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem):
-    raise TypeError(f"Can only load from references (got {x_smem}).")
+def _get_lowering_rule(ctx: LoweringRuleContext, x_ref, *leaves, tree):
+  if isinstance(x_ref, tcgen05.TMEMRef):
+    transforms = jax.tree.unflatten(tree, leaves)
+    if len(transforms) != 1 or not isinstance(
+        transforms[0], indexing.NDIndexer):
+      raise NotImplementedError(
+          "Only a single indexing transform is supported for TMEM refs.")
+    indexer = cast(indexing.NDIndexer, transforms[0])
+    if not gpu_core.is_trivial_index(indexer.indices, x_ref.shape):
+      raise NotImplementedError(
+          "Only trivial indexing is supported for TMEM refs.")
+    return x_ref[:]
+
+  if not isinstance(x_ref, ir.Value) and ir.MemRefType.isinstance(x_ref):
+    raise TypeError(f"Can only load from references (got {x_ref}).")
 
   x_aval = ctx.avals_in[0]
 
   transforms = jax.tree.unflatten(tree, leaves)
-  x_smem, transforms = _handle_reshaping(x_smem, transforms)
+  x_smem, transforms = _handle_reshaping(x_ref, transforms)
   x_smem, transforms = _handle_indexing(x_smem, transforms)
 
   match transforms:
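In practice the TMEM branch above means a kernel may load the whole ref but not a sub-slice. A sketch of what the rule accepts, assuming `is_trivial_index` treats full slices as trivial (the new test only exercises `tmem_ref[...]`):

# Inside a kernel, given tmem_ref of shape (128, 128):
x = tmem_ref[...]       # ok: trivial index covering the full ref
y = tmem_ref[0:64, :]   # NotImplementedError: only trivial indexing is supported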
@@ -1784,6 +1864,14 @@ def _run_scoped_lowering_rule(
       )
       input_refs.append(input_ref)
       should_discharge.append(False)
+    elif aval.memory_space == gpu_core.TMEM:
+      input_ref = alloc_stack.enter_context(
+          ctx.module_ctx.alloc_tmem(
+              jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype),
+          )
+      )
+      input_refs.append(input_ref)
+      should_discharge.append(False)
     else:
       raise ValueError(f"Can't convert to ref: {aval}")
 
jax/experimental/mosaic/gpu/core.py

Lines changed: 2 additions & 0 deletions
@@ -307,6 +307,8 @@ def _smem_tree_size(smem_buffers: ShapeTree) -> int:
           raise NotImplementedError("Misaligned barrier allocation")
         size += num_barriers * utils.MBARRIER_BYTES
       case TMEM(_):
+        # TODO(justinfu): This can trigger misaligned barrier allocations
+        # if TMEM is requested before barriers b/c it's not divisible by 8.
        size += 4  # i32 takes up 4 bytes
       case _:
         size += _count_buffer_bytes(l)
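The TODO is an alignment-bookkeeping concern: each TMEM entry adds only 4 bytes (one i32 address slot) to the running SMEM size, while barriers need 8-byte alignment. A sketch of the arithmetic, assuming `utils.MBARRIER_BYTES == 8`:

MBARRIER_BYTES = 8   # assumed value of utils.MBARRIER_BYTES
size = 0
size += 4            # one TMEM ref: an i32 address slot
# If barriers were sized next, the offset would not be 8-byte aligned,
# which is the misalignment the TODO warns about.
assert size % MBARRIER_BYTES != 0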

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 2 additions & 0 deletions
@@ -51,3 +51,5 @@
 GMEM = GPUMemorySpace.GMEM
 #: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.SMEM`.
 SMEM = GPUMemorySpace.SMEM
+#: Alias of :data:`jax.experimental.pallas.mosaic_gpu.GPUMemorySpace.TMEM`.
+TMEM = GPUMemorySpace.TMEM

tests/pallas/mosaic_gpu_test.py

Lines changed: 29 additions & 0 deletions
@@ -83,6 +83,13 @@ def setUp(self):
     super().setUp()
 
 
+class PallasSm100ATest(PallasTest, jtu.CudaArchSpecificTest):
+
+  def setUp(self):
+    self.skip_unless_sm100a()
+    super().setUp()
+
+
 class PallasCallTest(PallasTest):
 
   @parameterized.product(
@@ -1531,6 +1538,28 @@ def scope(acc_ref):
     np.testing.assert_allclose(res, a @ b, rtol=1e-3)
 
 
+class PallasCallSm100ATest(PallasSm100ATest):
+
+  def test_tmem_alloc(self):
+    mesh = plgpu.GPUMesh(num_threads=1, axis_names=("x"))
+    @pl.run_state
+    def inner(y_ref):
+      @pl.core_map(mesh)
+      def _():
+        def scope(tmem_ref, smem_ref):
+          # Issue a write so the TMEM load is not DCE'd.
+          smem_ref[...] = tmem_ref[...]
+          plgpu.commit_smem()
+          plgpu.copy_smem_to_gmem(smem_ref, y_ref)
+          plgpu.wait_smem_to_gmem(0)
+        pl.run_scoped(scope,
+                      plgpu.TMEM((128, 128), jnp.float32),
+                      plgpu.SMEM((128, 128), jnp.float32))
+    y_init = jnp.zeros((128, 128), np.float32)
+    # Test that this runs without errors.
+    jax.block_until_ready(inner(y_init))
+
+
 class PipelineTest(PallasTest):
 
   def test_pipeline_mode(self):
