Skip to content

Commit 19fcae1

Browse files
[Mosaic GPU] Add support for replicated warp_dim parsing and a dedicated test for parsing all canonical layouts.
PiperOrigin-RevId: 745015431
1 parent 51dbcd4 commit 19fcae1

File tree

3 files changed

+30
-10
lines changed

3 files changed

+30
-10
lines changed

jax/experimental/mosaic/gpu/layouts.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
9696

9797
_tiled_layout_attr_pattern = re.compile(
9898
r"^#mosaic_gpu.TiledLayout<\[(?P<tiling>.*)\],"
99-
r" warp_dim\s*=\s*(?P<warp_dim>[-\d]+),"
99+
r" warp_dim\s*=\s*(?P<warp_dim>.+),"
100100
r" lane_dims\s*=\s*\[(?P<lane_dims>.*)\],"
101101
r" vector_dim\s*=\s*(?P<vector_dim>[-\d]+)>$"
102102
)
@@ -107,22 +107,26 @@ def to_tiled_layout_attr(
107107
) -> ir.Attribute:
108108
"""Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout."""
109109

110-
def _lane_dim_str(d: int | fa.Replicated) -> str:
110+
def _int_or_replicated(d: int | fa.Replicated) -> str:
111111
if isinstance(d, fa.Replicated):
112112
return f"#mosaic_gpu.Replicated<times={d.times}>"
113113
return str(d)
114114

115115
tile_str = lambda tile: "[" + ", ".join(str(d) for d in tile) + "]"
116116
tiling = "[" + ", ".join(tile_str(tile) for tile in layout.tiling.tiles) + "]"
117-
lane_dims = "[" + ",".join(_lane_dim_str(d) for d in layout.lane_dims) + "]"
117+
lane_dims = (
118+
"[" + ",".join(_int_or_replicated(d) for d in layout.lane_dims) + "]"
119+
)
118120

119121
return ir.Attribute.parse(
120-
f"#mosaic_gpu.TiledLayout<{tiling}, warp_dim={layout.warp_dim},"
122+
f"#mosaic_gpu.TiledLayout<{tiling},"
123+
f" warp_dim={_int_or_replicated(layout.warp_dim)},"
121124
f" lane_dims={lane_dims}, vector_dim={layout.vector_dim}>"
122125
)
123126

124127

125128
_list_of_lists_delimiter = re.compile(r"\]\s*,\s*\[")
129+
_int_pattern = re.compile(r"^(?P<num>[-\d]+)(\s*:\s*\w+)?$")
126130
_replicated_pattern = re.compile(
127131
r"^#mosaic_gpu.Replicated<\s*times\s*=\s*(?P<times>\d+)\s*>\s*$"
128132
)
@@ -143,11 +147,14 @@ def from_tiled_layout_attr(
143147
f"Expected a #mosaic_gpu.TiledLayout attribute, got {attr}"
144148
)
145149

146-
def _lane_dim(lane_dim_str: str) -> int | fa.Replicated:
147-
match = _replicated_pattern.fullmatch(lane_dim_str)
150+
def _int_or_replicated(replicated_dim: str) -> int | fa.Replicated:
151+
match = _replicated_pattern.fullmatch(replicated_dim)
148152
if match:
149153
return fa.Replicated(int(match.group("times")))
150-
return int(lane_dim_str)
154+
match = _int_pattern.fullmatch(replicated_dim)
155+
if match:
156+
return int(match.group("num"))
157+
raise ValueError(f"Unexpected format for replicated dim {replicated_dim}")
151158

152159
tiling_str = match.group("tiling")
153160
tile_strings = []
@@ -156,9 +163,10 @@ def _lane_dim(lane_dim_str: str) -> int | fa.Replicated:
156163
tiles = tuple(tuple(map(int, ts.split(","))) for ts in tile_strings)
157164
return fa.TiledLayout(
158165
tiling=fa.Tiling(tiles),
159-
warp_dim=int(match.group("warp_dim")),
166+
warp_dim=_int_or_replicated(match.group("warp_dim")),
160167
lane_dims=tuple(
161-
_lane_dim(s) for s in match.group("lane_dims").split(",")
168+
_int_or_replicated(s.strip())
169+
for s in match.group("lane_dims").split(",")
162170
),
163171
vector_dim=int(match.group("vector_dim")),
164172
)

jaxlib/mosaic/dialect/gpu/mosaic_gpu.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def MosaicGPU_TiledLayout : AttrDef<MosaicGPU_Dialect, "TiledLayout", []> {
161161

162162
let parameters = (ins
163163
"::mlir::ArrayAttr":$tiling,
164-
"int":$warp_dim,
164+
"::mlir::Attribute":$warp_dim,
165165
"::mlir::ArrayAttr":$lane_dims,
166166
"int":$vector_dim
167167
);

tests/mosaic/gpu_dialect_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,18 @@ def test_wgmma_b_n_dim_not_equal_to_acc_n_dim(self):
593593
):
594594
self.module.operation.verify()
595595

596+
def test_tiled_layout_attr_parsing(self):
597+
with ir.InsertionPoint(self.module.body):
598+
for layout in (
599+
mgpu.WGMMA_LAYOUT,
600+
mgpu.WGMMA_ROW_LAYOUT,
601+
mgpu.WGMMA_COL_LAYOUT,
602+
mgpu.WGMMA_TRANSPOSED_LAYOUT,
603+
):
604+
attr = layouts.to_tiled_layout_attr(layout)
605+
parsed_layout = layouts.from_tiled_layout_attr(attr)
606+
self.assertEqual(layout, parsed_layout)
607+
596608

597609
class DialectLoweringTest(MosaicGpuTest):
598610

0 commit comments

Comments
 (0)