5959from jax .experimental .mosaic .gpu import core as mgpu_core
6060from jax .experimental .mosaic .gpu import profiler as mgpu_profiler
6161from jax .experimental .mosaic .gpu import utils as mgpu_utils
62+ from jax .experimental .mosaic .gpu import tcgen05
6263import jax .numpy as jnp
6364import numpy as np
6465
@@ -100,6 +101,7 @@ def arrival_multiplier(self) -> int:
100101@dataclasses .dataclass (kw_only = True , frozen = True )
101102class Resources :
102103 smem_scratch_bytes : int = 0
104+ tmem_scratch_cols : int = 0
103105 barrier_counts : collections .Counter [mgpu .Barrier ] = dataclasses .field (
104106 default_factory = collections .Counter
105107 )
@@ -110,6 +112,12 @@ def __post_init__(self):
110112 "smem_scratch_bytes" ,
111113 _align_to (self .smem_scratch_bytes , _SMEM_ALIGNMENT ),
112114 )
115+ object .__setattr__ (
116+ self ,
117+ "tmem_scratch_cols" ,
118+ # TMEM must be allocated in 128x8 chunks.
119+ _align_to (self .tmem_scratch_cols , 8 ),
120+ )
113121
114122 @property
115123 def barriers (self ) -> Sequence [mgpu .Barrier ]:
@@ -122,6 +130,7 @@ def __add__(self, other: Resources) -> Resources:
122130 # we will allocate two barriers, even though one would be enough.
123131 return Resources (
124132 smem_scratch_bytes = self .smem_scratch_bytes + other .smem_scratch_bytes ,
133+ tmem_scratch_cols = self .tmem_scratch_cols + other .tmem_scratch_cols ,
125134 barrier_counts = self .barrier_counts + other .barrier_counts ,
126135 )
127136
@@ -130,6 +139,9 @@ def __or__(self, other: Resources) -> Resources:
130139 smem_scratch_bytes = max (
131140 self .smem_scratch_bytes , other .smem_scratch_bytes
132141 ),
142+ tmem_scratch_cols = max (
143+ self .tmem_scratch_cols , other .tmem_scratch_cols
144+ ),
133145 barrier_counts = self .barrier_counts | other .barrier_counts ,
134146 )
135147
@@ -218,10 +230,26 @@ def _run_scoped_resource_estimator(
218230 )
219231 ])
220232 )
221- else :
233+ elif aval .memory_space == gpu_core .TMEM :
234+ if aval .dtype .itemsize != 4 :
235+ raise ValueError ("TMEM only supports 32-bit types." )
236+ if len (aval .shape ) != 2 :
237+ raise ValueError ("TMEM allocations must be 2D." )
238+ if aval .shape [0 ] % tcgen05 .TMEM_ROWS != 0 :
239+ raise ValueError ("TMEM shape[0] must be a multiple of 128." )
240+ if aval .shape [1 ] % 8 != 0 :
241+ raise ValueError ("TMEM shape[1] must be a multiple of 8." )
242+ rs += Resources (tmem_scratch_cols = aval .shape [1 ])
243+ elif aval .memory_space == gpu_core .SMEM :
222244 rs += Resources (
223245 smem_scratch_bytes = math .prod (aval .shape ) * aval .dtype .itemsize
224246 )
247+ elif aval .memory_space == gpu_core .REGS :
248+ # Don't need to allocate anything.
249+ pass
250+ else :
251+ raise NotImplementedError (
252+ f"Unsupported memory space: { aval .memory_space } " )
225253 return rs + _estimate_resources (ctx , jaxpr )
226254
227255
@@ -267,6 +295,9 @@ class ModuleContext:
267295 single_wg_lane_predicate : ir .Value
268296 smem_requested_bytes : int
269297 smem_used_bytes : int
298+ tmem_requested_cols : int
299+ tmem_used_cols : int
300+ tmem_base_ptr : ir .Value
270301 runtime_barriers : MutableMapping [
271302 mgpu .Barrier , MutableSequence [mgpu .BarrierRef ]
272303 ]
@@ -286,6 +317,27 @@ def reserve_barrier(self, barrier: mgpu.Barrier) -> mgpu.BarrierRef:
286317 raise RuntimeError (f"Barrier { barrier } is already reserved" )
287318 return available .pop ()
288319
320+ @contextlib .contextmanager
321+ def alloc_tmem (
322+ self ,
323+ struct : jax .ShapeDtypeStruct ,
324+ layout : tcgen05 .TMEMLayout | None = None
325+ ) -> ir .Value :
326+ if self .tmem_used_cols > 0 :
327+ raise NotImplementedError (
328+ "Multiple TMEM allocations are not implemented." )
329+ if layout is None :
330+ layout = tcgen05 ._infer_tmem_layout (struct .shape , collective = False )
331+ cols_used = np .prod (struct .shape ) // tcgen05 .TMEM_ROWS
332+ self .tmem_used_cols += cols_used
333+ off = self .tmem_base_ptr
334+ tmem_ref = tcgen05 .TMEMRef (address = off ,
335+ shape = struct .shape ,
336+ dtype = mgpu_utils .dtype_to_ir_type (struct .dtype ),
337+ layout = layout )
338+ yield tmem_ref
339+ self .tmem_used_cols -= cols_used
340+
289341 # TODO(cperivol): Only return the shapes and figure out the sizes when freeing.
290342 @contextlib .contextmanager
291343 def scratch_view (
@@ -642,11 +694,15 @@ def lower_jaxpr_to_module(
642694 parallel_grid = (math .prod (grid [:- 2 ]), * grid [- 2 :])
643695
644696 def body (launch_ctx : mgpu .LaunchContext , * buffers : ir .Value ):
645- * buffers_gmem , (runtime_smem , runtime_barriers ) = buffers
697+ * buffers_gmem , (runtime_smem , runtime_barriers , runtime_tmem ) = buffers
646698
647699 grouped_barriers = collections .defaultdict (list )
648700 for barrier , barrier_ref in zip (rs .barriers , runtime_barriers ):
649701 grouped_barriers [barrier ].append (barrier_ref )
702+ if runtime_tmem is not None :
703+ tmem_cols = math .prod (runtime_tmem .shape ) // tcgen05 .TMEM_ROWS
704+ else :
705+ tmem_cols = 0
650706 module_ctx = ModuleContext (
651707 mlir .sanitize_name (debug_info .func_name ),
652708 axis_names ,
@@ -655,6 +711,9 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
655711 mgpu .single_thread_predicate (per_block = False ),
656712 smem_requested_bytes = math .prod (ir .MemRefType (runtime_smem .type ).shape ),
657713 smem_used_bytes = 0 ,
714+ tmem_requested_cols = tmem_cols ,
715+ tmem_used_cols = 0 ,
716+ tmem_base_ptr = runtime_tmem .address if runtime_tmem else None ,
658717 runtime_barriers = grouped_barriers ,
659718 name_stack = source_info_util .NameStack (),
660719 traceback_caches = mlir .TracebackCaches (),
@@ -671,6 +730,18 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
671730 smem_scratch_bytes = params .get ("smem_scratch_bytes" )
672731 if smem_scratch_bytes is None :
673732 smem_scratch_bytes = rs .smem_scratch_bytes
733+ tmem_scratch_cols = rs .tmem_scratch_cols
734+
735+ scratch_buffers = [
736+ jax .ShapeDtypeStruct (shape = [smem_scratch_bytes ], dtype = np .int8 ),
737+ rs .barriers ,
738+ ]
739+ if tmem_scratch_cols > 0 :
740+ scratch_buffers .append (
741+ mgpu .TMEM (shape = [tcgen05 .TMEM_ROWS , tmem_scratch_cols ], dtype = np .int32 ),
742+ )
743+ else :
744+ scratch_buffers .append (None )
674745
675746 prof_ctx = prof_spec = None
676747 if prof_space := params .get ("profile_space" , 0 ):
@@ -685,10 +756,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
685756 block = block ,
686757 in_shapes = in_shapes ,
687758 out_shape = out_shapes ,
688- smem_scratch_shape = (
689- jax .ShapeDtypeStruct (shape = [smem_scratch_bytes ], dtype = np .int8 ),
690- rs .barriers ,
691- ),
759+ smem_scratch_shape = scratch_buffers ,
692760 module_name = mlir .sanitize_name (debug_info .func_name ),
693761 prof_spec = prof_spec ,
694762 )
@@ -990,14 +1058,26 @@ def _ndindexer_indices(indexer: indexing.NDIndexer) -> tuple[gpu_core.Index, ...
9901058
9911059
9921060@register_lowering_rule (sp .get_p , mgpu .ThreadSemantics .Lane )
993- def _get_lowering_rule (ctx : LoweringRuleContext , x_smem , * leaves , tree ):
994- if not isinstance (x_smem , ir .Value ) and ir .MemRefType .isinstance (x_smem ):
995- raise TypeError (f"Can only load from references (got { x_smem } )." )
1061+ def _get_lowering_rule (ctx : LoweringRuleContext , x_ref , * leaves , tree ):
1062+ if isinstance (x_ref , tcgen05 .TMEMRef ):
1063+ transforms = jax .tree .unflatten (tree , leaves )
1064+ if len (transforms ) != 1 or not isinstance (
1065+ transforms [0 ], indexing .NDIndexer ):
1066+ raise NotImplementedError (
1067+ "Only a single indexing transform is supported for TMEM refs." )
1068+ indexer = cast (indexing .NDIndexer , transforms [0 ])
1069+ if not gpu_core .is_trivial_index (indexer .indices , x_ref .shape ):
1070+ raise NotImplementedError (
1071+ "Only trivial indexing is supported for TMEM refs." )
1072+ return x_ref [:]
1073+
1074+ if not isinstance (x_ref , ir .Value ) and ir .MemRefType .isinstance (x_ref ):
1075+ raise TypeError (f"Can only load from references (got { x_ref } )." )
9961076
9971077 x_aval = ctx .avals_in [0 ]
9981078
9991079 transforms = jax .tree .unflatten (tree , leaves )
1000- x_smem , transforms = _handle_reshaping (x_smem , transforms )
1080+ x_smem , transforms = _handle_reshaping (x_ref , transforms )
10011081 x_smem , transforms = _handle_indexing (x_smem , transforms )
10021082
10031083 match transforms :
@@ -1784,6 +1864,14 @@ def _run_scoped_lowering_rule(
17841864 )
17851865 input_refs .append (input_ref )
17861866 should_discharge .append (False )
1867+ elif aval .memory_space == gpu_core .TMEM :
1868+ input_ref = alloc_stack .enter_context (
1869+ ctx .module_ctx .alloc_tmem (
1870+ jax .ShapeDtypeStruct (shape = aval .shape , dtype = aval .dtype ),
1871+ )
1872+ )
1873+ input_refs .append (input_ref )
1874+ should_discharge .append (False )
17871875 else :
17881876 raise ValueError (f"Can't convert to ref: { aval } " )
17891877
0 commit comments