|
25 | 25 | from jax._src.lib.mlir import ir |
26 | 26 | from jax._src.lib.mlir.dialects import arith |
27 | 27 | from jax._src.lib.mlir.dialects import func |
| 28 | +from jax._src.lib.mlir.dialects import vector |
28 | 29 | import jax.experimental.mosaic.gpu as mgpu |
| 30 | +from jax.experimental.mosaic.gpu import fragmented_array as fa |
29 | 31 | from jax.experimental.mosaic.gpu import inference_utils |
| 32 | +from jax.experimental.mosaic.gpu import layouts as layouts_lib |
30 | 33 | import numpy as np |
31 | 34 |
|
32 | 35 |
|
@@ -162,6 +165,187 @@ def body(gmem_ref, smem_ref): |
162 | 165 | ) |
163 | 166 | self.assertEmpty(inference_utils.out_transforms(async_store_op)) |
164 | 167 |
|
| 168 | + def test_infer_transforms_for_vector_load_op_derives_from_destination(self): |
| 169 | + vector_load_op = None |
| 170 | + shape = (64, 64) |
| 171 | + elt_ty = ir.BF16Type.get() |
| 172 | + |
| 173 | + def body(smem_ref): |
| 174 | + nonlocal vector_load_op |
| 175 | + zero = arith.constant(ir.IntegerType.get_signless(32), 0) |
| 176 | + vector_load_op = vector.LoadOp( |
| 177 | + ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape) |
| 178 | + ) |
| 179 | + |
| 180 | + with ir.InsertionPoint(self.module.body): |
| 181 | + smem = ir.Attribute.parse("#gpu.address_space<workgroup>") |
| 182 | + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) |
| 183 | + func.FuncOp.from_py_func(smem_ty)(body) |
| 184 | + |
| 185 | + vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( |
| 186 | + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] |
| 187 | + ) |
| 188 | + |
| 189 | + mgpu.infer_transforms(self.module) |
| 190 | + |
| 191 | + expected_transforms = ir.ArrayAttr.get([ |
| 192 | + mgpu.dialect.TileTransformAttr.get((8, 64)), |
| 193 | + mgpu.dialect.SwizzleTransformAttr.get(128), |
| 194 | + ]) |
| 195 | + |
| 196 | + self.assertSequenceEqual( |
| 197 | + inference_utils.in_transforms(vector_load_op), [expected_transforms] |
| 198 | + ) |
| 199 | + self.assertEmpty(inference_utils.out_transforms(vector_load_op)) |
| 200 | + |
| 201 | + def test_infer_transforms_for_vector_load_op_derives_from_source(self): |
| 202 | + vector_load_op = None |
| 203 | + shape = (64, 64) |
| 204 | + elt_ty = ir.BF16Type.get() |
| 205 | + |
| 206 | + def body(smem_ref): |
| 207 | + nonlocal vector_load_op |
| 208 | + zero = arith.constant(ir.IntegerType.get_signless(32), 0) |
| 209 | + vector_load_op = vector.LoadOp( |
| 210 | + ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape) |
| 211 | + ) |
| 212 | + |
| 213 | + with ir.InsertionPoint(self.module.body): |
| 214 | + smem = ir.Attribute.parse("#gpu.address_space<workgroup>") |
| 215 | + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) |
| 216 | + f = func.FuncOp.from_py_func(smem_ty)(body).func_op |
| 217 | + |
| 218 | + vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( |
| 219 | + [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))] |
| 220 | + ) |
| 221 | + transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) |
| 222 | + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) |
| 223 | + |
| 224 | + mgpu.infer_transforms(self.module) |
| 225 | + |
| 226 | + self.assertSequenceEqual( |
| 227 | + inference_utils.in_transforms(vector_load_op), [transforms] |
| 228 | + ) |
| 229 | + self.assertEmpty(inference_utils.out_transforms(vector_load_op)) |
| 230 | + |
| 231 | + def test_infer_transforms_for_vector_load_op_raises_on_mismatches(self): |
| 232 | + vector_load_op = None |
| 233 | + shape = (64, 64) |
| 234 | + elt_ty = ir.BF16Type.get() |
| 235 | + |
| 236 | + def body(smem_ref): |
| 237 | + nonlocal vector_load_op |
| 238 | + zero = arith.constant(ir.IntegerType.get_signless(32), 0) |
| 239 | + vector_load_op = vector.LoadOp( |
| 240 | + ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape) |
| 241 | + ) |
| 242 | + |
| 243 | + with ir.InsertionPoint(self.module.body): |
| 244 | + smem = ir.Attribute.parse("#gpu.address_space<workgroup>") |
| 245 | + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) |
| 246 | + f = func.FuncOp.from_py_func(smem_ty)(body).func_op |
| 247 | + |
| 248 | + vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get( |
| 249 | + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] |
| 250 | + ) |
| 251 | + transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) |
| 252 | + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) |
| 253 | + |
| 254 | + with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): |
| 255 | + mgpu.infer_transforms(self.module) |
| 256 | + |
| 257 | + def test_infer_transforms_for_vector_store_op_derives_from_destination(self): |
| 258 | + vector_store_op = None |
| 259 | + shape = (64, 64) |
| 260 | + elt_ty = ir.BF16Type.get() |
| 261 | + |
| 262 | + def body(smem_ref, value_to_store): |
| 263 | + nonlocal vector_store_op |
| 264 | + zero = arith.constant(ir.IntegerType.get_signless(32), 0) |
| 265 | + vector_store_op = vector.StoreOp( |
| 266 | + value_to_store, smem_ref, [zero] * len(shape) |
| 267 | + ) |
| 268 | + |
| 269 | + with ir.InsertionPoint(self.module.body): |
| 270 | + smem = ir.Attribute.parse("#gpu.address_space<workgroup>") |
| 271 | + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) |
| 272 | + value_ty = ir.VectorType.get(shape, elt_ty) |
| 273 | + func.FuncOp.from_py_func(smem_ty, value_ty)(body) |
| 274 | + |
| 275 | + vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get( |
| 276 | + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] |
| 277 | + ) |
| 278 | + |
| 279 | + mgpu.infer_transforms(self.module) |
| 280 | + |
| 281 | + expected_transforms = ir.ArrayAttr.get([ |
| 282 | + mgpu.dialect.TileTransformAttr.get((8, 64)), |
| 283 | + mgpu.dialect.SwizzleTransformAttr.get(128), |
| 284 | + ]) |
| 285 | + |
| 286 | + self.assertSequenceEqual( |
| 287 | + inference_utils.in_transforms(vector_store_op), [expected_transforms] |
| 288 | + ) |
| 289 | + self.assertEmpty(inference_utils.out_transforms(vector_store_op)) |
| 290 | + |
| 291 | + def test_infer_transforms_for_vector_store_op_derives_from_source(self): |
| 292 | + vector_store_op = None |
| 293 | + shape = (64, 64) |
| 294 | + elt_ty = ir.BF16Type.get() |
| 295 | + |
| 296 | + def body(smem_ref, value_to_store): |
| 297 | + nonlocal vector_store_op |
| 298 | + zero = arith.constant(ir.IntegerType.get_signless(32), 0) |
| 299 | + vector_store_op = vector.StoreOp( |
| 300 | + value_to_store, smem_ref, [zero] * len(shape) |
| 301 | + ) |
| 302 | + |
| 303 | + with ir.InsertionPoint(self.module.body): |
| 304 | + smem = ir.Attribute.parse("#gpu.address_space<workgroup>") |
| 305 | + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) |
| 306 | + value_ty = ir.VectorType.get(shape, elt_ty) |
| 307 | + f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op |
| 308 | + |
| 309 | + vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get( |
| 310 | + [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))] |
| 311 | + ) |
| 312 | + transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) |
| 313 | + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) |
| 314 | + |
| 315 | + mgpu.infer_transforms(self.module) |
| 316 | + |
| 317 | + self.assertSequenceEqual( |
| 318 | + inference_utils.in_transforms(vector_store_op), [transforms] |
| 319 | + ) |
| 320 | + self.assertEmpty(inference_utils.out_transforms(vector_store_op)) |
| 321 | + |
| 322 | + def test_infer_transforms_for_vector_store_op_raises_on_mismatches(self): |
| 323 | + vector_store_op = None |
| 324 | + shape = (64, 64) |
| 325 | + elt_ty = ir.BF16Type.get() |
| 326 | + |
| 327 | + def body(smem_ref, value_to_store): |
| 328 | + nonlocal vector_store_op |
| 329 | + zero = arith.constant(ir.IntegerType.get_signless(32), 0) |
| 330 | + vector_store_op = vector.StoreOp( |
| 331 | + value_to_store, smem_ref, [zero] * len(shape) |
| 332 | + ) |
| 333 | + |
| 334 | + with ir.InsertionPoint(self.module.body): |
| 335 | + smem = ir.Attribute.parse("#gpu.address_space<workgroup>") |
| 336 | + smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem) |
| 337 | + value_ty = ir.VectorType.get(shape, elt_ty) |
| 338 | + f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op |
| 339 | + |
| 340 | + vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get( |
| 341 | + [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)] |
| 342 | + ) |
| 343 | + transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))]) |
| 344 | + f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms]) |
| 345 | + |
| 346 | + with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"): |
| 347 | + mgpu.infer_transforms(self.module) |
| 348 | + |
165 | 349 |
|
166 | 350 | if __name__ == "__main__": |
167 | 351 | parameterized.absltest.main(testLoader=jtu.JaxTestLoader()) |
0 commit comments