@@ -96,7 +96,7 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
9696
9797_tiled_layout_attr_pattern = re .compile (
9898 r"^#mosaic_gpu.TiledLayout<\[(?P<tiling>.*)\],"
99- r" warp_dim\s*=\s*(?P<warp_dim>[-\d] +),"
99+ r" warp_dim\s*=\s*(?P<warp_dim>. +),"
100100 r" lane_dims\s*=\s*\[(?P<lane_dims>.*)\],"
101101 r" vector_dim\s*=\s*(?P<vector_dim>[-\d]+)>$"
102102)
@@ -107,22 +107,26 @@ def to_tiled_layout_attr(
107107) -> ir .Attribute :
108108 """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout."""
109109
110- def _lane_dim_str (d : int | fa .Replicated ) -> str :
110+ def _int_or_replicated (d : int | fa .Replicated ) -> str :
111111 if isinstance (d , fa .Replicated ):
112112 return f"#mosaic_gpu.Replicated<times={ d .times } >"
113113 return str (d )
114114
115115 tile_str = lambda tile : "[" + ", " .join (str (d ) for d in tile ) + "]"
116116 tiling = "[" + ", " .join (tile_str (tile ) for tile in layout .tiling .tiles ) + "]"
117- lane_dims = "[" + "," .join (_lane_dim_str (d ) for d in layout .lane_dims ) + "]"
117+ lane_dims = (
118+ "[" + "," .join (_int_or_replicated (d ) for d in layout .lane_dims ) + "]"
119+ )
118120
119121 return ir .Attribute .parse (
120- f"#mosaic_gpu.TiledLayout<{ tiling } , warp_dim={ layout .warp_dim } ,"
122+ f"#mosaic_gpu.TiledLayout<{ tiling } ,"
123+ f" warp_dim={ _int_or_replicated (layout .warp_dim )} ,"
121124 f" lane_dims={ lane_dims } , vector_dim={ layout .vector_dim } >"
122125 )
123126
124127
125128_list_of_lists_delimiter = re .compile (r"\]\s*,\s*\[" )
129+ _int_pattern = re .compile (r"^(?P<num>[-\d]+)(\s*:\s*\w+)?$" )
126130_replicated_pattern = re .compile (
127131 r"^#mosaic_gpu.Replicated<\s*times\s*=\s*(?P<times>\d+)\s*>\s*$"
128132)
@@ -143,11 +147,14 @@ def from_tiled_layout_attr(
143147 f"Expected a #mosaic_gpu.TiledLayout attribute, got { attr } "
144148 )
145149
146- def _lane_dim ( lane_dim_str : str ) -> int | fa .Replicated :
147- match = _replicated_pattern .fullmatch (lane_dim_str )
150+ def _int_or_replicated ( replicated_dim : str ) -> int | fa .Replicated :
151+ match = _replicated_pattern .fullmatch (replicated_dim )
148152 if match :
149153 return fa .Replicated (int (match .group ("times" )))
150- return int (lane_dim_str )
154+ match = _int_pattern .fullmatch (replicated_dim )
155+ if match :
156+ return int (match .group ("num" ))
157+ raise ValueError (f"Unexpected format for replicated dim { replicated_dim } " )
151158
152159 tiling_str = match .group ("tiling" )
153160 tile_strings = []
@@ -156,9 +163,10 @@ def _lane_dim(lane_dim_str: str) -> int | fa.Replicated:
156163 tiles = tuple (tuple (map (int , ts .split ("," ))) for ts in tile_strings )
157164 return fa .TiledLayout (
158165 tiling = fa .Tiling (tiles ),
159- warp_dim = int (match .group ("warp_dim" )),
166+ warp_dim = _int_or_replicated (match .group ("warp_dim" )),
160167 lane_dims = tuple (
161- _lane_dim (s ) for s in match .group ("lane_dims" ).split ("," )
168+ _int_or_replicated (s .strip ())
169+ for s in match .group ("lane_dims" ).split ("," )
162170 ),
163171 vector_dim = int (match .group ("vector_dim" )),
164172 )
0 commit comments