@@ -201,6 +201,7 @@ class ModuleContext:
201201 ]
202202 name_stack : source_info_util .NameStack
203203 traceback_caches : mlir .TracebackCaches
204+ squashed_dims : tuple [int , ...]
204205
205206 def reserve_barrier (self , barrier : mgpu .Barrier ) -> mgpu .BarrierRef :
206207 """Reserves a barrier.
@@ -403,12 +404,15 @@ def lower_jaxpr_to_module(
403404 parallel_grid = [
404405 d for i , d in enumerate (logical_grid ) if i not in sequential_axes
405406 ]
406- if len (parallel_grid ) < 3 :
407+ if len (parallel_grid ) <= 3 :
408+ squashed_dims = ()
407409 parallel_grid += (1 ,) * (3 - len (parallel_grid ))
408- elif len (parallel_grid ) > 3 :
409- raise NotImplementedError (
410- "Only <=3D grids are supported in Mosaic GPU lowering."
411- )
410+ else :
411+ # If we have >3 parallel dimensions, we merge all leading dimensions
412+ # into the first (Dimension.x) CUDA grid dimension.
413+ squashed_dims = parallel_grid [:- 2 ]
414+ parallel_grid = [math .prod (parallel_grid [:- 2 ]), * parallel_grid [- 2 :]]
415+
412416 if sequential_axes :
413417 # TODO(slebedev): Support multiple sequential axes.
414418 if len (sequential_axes ) > 1 :
@@ -496,7 +500,7 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
496500
497501 parallel_count = it .count ()
498502 program_ids_template = [
499- _program_id (next (parallel_count ))
503+ _program_id (next (parallel_count ), squashed_dims = squashed_dims )
500504 if axis not in sequential_axes
501505 else None
502506 for axis in range (len (logical_grid ))
@@ -520,6 +524,7 @@ def make_program_ids(step: ir.Value):
520524 runtime_barriers = grouped_barriers ,
521525 name_stack = source_info_util .NameStack (),
522526 traceback_caches = mlir .TracebackCaches (),
527+ squashed_dims = squashed_dims ,
523528 )
524529 del runtime_smem , grouped_barriers , runtime_barriers
525530
@@ -911,12 +916,42 @@ def _program_id_lowering_rule(ctx: LoweringRuleContext, axis):
911916 raise NotImplementedError ("pl.program_id() is not supported in this context" )
912917 return ctx .module_ctx .program_ids [axis ]
913918
914-
915- def _program_id (axis : int ) -> ir .Value :
916- return arith_dialect .index_cast (
917- ir .IntegerType .get_signless (32 ),
918- gpu_dialect .block_id (gpu_dialect .Dimension (axis )),
919- )
919+ def _unravel_program_id (
920+ block_id : ir .Value ,
921+ axis : int ,
922+ dimensions : tuple [int , ...],
923+ row_major : bool = False
924+ ) -> ir .Value :
925+ """Computes the program ID for axes compressed into one block dimension."""
926+ if row_major :
927+ div_value = math .prod (dimensions [axis + 1 :])
928+ else :
929+ div_value = math .prod (dimensions [:axis ])
930+ div_value = _as_index (_i32_constant (div_value ))
931+ pid = arith_dialect .divui (block_id , div_value )
932+ axis_size = _as_index (_i32_constant (dimensions [axis ]))
933+ pid = arith_dialect .remui (pid , axis_size )
934+ return arith_dialect .index_cast (ir .IntegerType .get_signless (32 ), pid )
935+
936+
937+ def _program_id (parallel_axis : int , squashed_dims : tuple [int , ...]) -> ir .Value :
938+ if squashed_dims :
939+ if parallel_axis < len (squashed_dims ):
940+ # All squashed dimensions are mapped to Dimension.x.
941+ block_id = gpu_dialect .block_id (gpu_dialect .Dimension .x )
942+ return _unravel_program_id (block_id , parallel_axis , squashed_dims )
943+ else :
944+ # Handle unsquashed axes.
945+ return arith_dialect .index_cast (
946+ ir .IntegerType .get_signless (32 ),
947+ gpu_dialect .block_id (gpu_dialect .Dimension (
948+ parallel_axis - len (squashed_dims ) + 1 )),
949+ )
950+ else :
951+ return arith_dialect .index_cast (
952+ ir .IntegerType .get_signless (32 ),
953+ gpu_dialect .block_id (gpu_dialect .Dimension (parallel_axis )),
954+ )
920955
921956
922957@register_lowering_rule (primitives .num_programs_p )
@@ -1244,16 +1279,44 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
12441279
12451280@register_lowering_rule (lax .axis_index_p )
12461281def _axis_index_rule (ctx : LoweringRuleContext , * , axis_name : Hashable ):
1282+ i32 = ir .IntegerType .get_signless (32 )
12471283 grid_names = ctx .module_ctx .grid_mapping .grid_names
1284+ squashed_dims = ctx .module_ctx .squashed_dims
1285+ if squashed_dims :
1286+ unsquashed_names = grid_names [- 3 :]
1287+ squashed_names = grid_names [:- 3 ]
1288+ else :
1289+ # These are unused but initialized for type checkers.
1290+ unsquashed_names = ()
1291+ squashed_names = ()
12481292 if grid_names and axis_name in grid_names :
12491293 if axis_name == grid_names [- 1 ]:
12501294 return mgpu .warpgroup_idx (sync = True )
12511295 else :
1252- idx = grid_names .index (axis_name )
1253- return arith_dialect .index_cast (
1254- ir .IntegerType .get_signless (32 ),
1255- gpu_dialect .block_id (gpu_dialect .Dimension (idx )),
1256- )
1296+ if squashed_dims :
1297+ if axis_name in unsquashed_names :
1298+ # We add 1 to the index because the first dimension is the
1299+ # squashed dimension.
1300+ # e.g. for the grid (a, b, c, d, wg)
1301+ # squashed = (a, b) Mapped to Dimension.x (0)
1302+ # unsquashed = (c, d) Mapped to Dimension.y (1) and Dimension.z (2)
1303+ idx = unsquashed_names .index (axis_name ) + 1
1304+ return arith_dialect .index_cast (
1305+ i32 ,
1306+ gpu_dialect .block_id (gpu_dialect .Dimension (idx )),
1307+ )
1308+ elif axis_name in squashed_names :
1309+ # All squashed dimensions are mapped to Dimension.x.
1310+ block_id = gpu_dialect .block_id (gpu_dialect .Dimension .x )
1311+ axis = squashed_names .index (axis_name )
1312+ return _unravel_program_id (block_id , axis , squashed_dims )
1313+ else :
1314+ if axis_name in grid_names :
1315+ idx = grid_names .index (axis_name )
1316+ return arith_dialect .index_cast (
1317+ i32 ,
1318+ gpu_dialect .block_id (gpu_dialect .Dimension (idx )),
1319+ )
12571320 raise ValueError (
12581321 "Named axes can only refer to GPUMesh axes in Mosaic GPU kernels"
12591322 )
@@ -1669,10 +1732,14 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value:
16691732
16701733
16711734def _i32_constant (v : int ) -> ir .Value :
1735+ if v < jnp .iinfo (jnp .int32 ).min or v > jnp .iinfo (jnp .int32 ).max :
1736+ raise ValueError (f"Integer constant out of range for i32: { v } " )
16721737 return arith_dialect .constant (ir .IntegerType .get_signless (32 ), v )
16731738
16741739
16751740def _i64_constant (v : int ) -> ir .Value :
1741+ if v < jnp .iinfo (jnp .int64 ).min or v > jnp .iinfo (jnp .int64 ).max :
1742+ raise ValueError (f"Integer constant out of range for i64: { v } " )
16761743 return arith_dialect .constant (ir .IntegerType .get_signless (64 ), v )
16771744
16781745
0 commit comments