Skip to content

Commit ba2f7c9

Browse files
bchetioui authored and Google-ML-Automation committed
[Mosaic GPU] Add transform inference rule for mgpu.slice_smem.
PiperOrigin-RevId: 737957778
1 parent d4bd257 commit ba2f7c9

File tree

2 files changed

+100
-1
lines changed

2 files changed

+100
-1
lines changed

jax/experimental/mosaic/gpu/transform_inference.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@
4343
# Registers `rule` in `_transform_inference_rules` under `op.OPERATION_NAME`
# and returns `rule` unchanged, so the function can also be applied as a
# decorator. In this diff view, the statement following the "46-" marker is the
# removed unconditional registration; the statements following "46+"/"47+" are
# its replacement, which skips registration when `op` is None.
def _add_transform_inference_rule(
4444
op: type[ir.OpView], rule: TransformInferenceRule
4545
):
46-
_transform_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error
46+
if op is not None:
47+
_transform_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error
4748
return rule
4849

4950

@@ -169,6 +170,32 @@ def _infer_vector_load_store_transforms(
169170
return None
170171

171172

173+
# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
# `SliceSMEMOp` is looked up dynamically and is None when `mgpu` does not
# expose it.
SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)


@partial(_add_transform_inference_rule, SliceSMEMOp)
def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
  """Infers the out transforms of a `SliceSMEMOp` from the users of its result.

  Every use of `op.result` is queried for the in-transforms its owner has
  recorded for that operand. Uses without recorded transforms are ignored; all
  remaining uses must agree.

  Args:
    op: The operation whose result uses are inspected.

  Returns:
    None if no user carries transforms for the result; otherwise a pair
    `([], [transforms])` holding no in-transforms and the single agreed-upon
    out-transforms entry.

  Raises:
    NotImplementedError: If two users require different transforms.
  """
  transforms = None

  for use in cast(ir.OpResult, op.result).uses:
    user = use.owner
    op_user = user.operands[use.operand_number]
    out_transforms = inference_utils.in_transforms_for_operand(user, op_user)

    # This user puts no constraint on the result.
    if out_transforms is None:
      continue

    # First constraining user: adopt its transforms.
    if transforms is None:
      transforms = out_transforms
      continue

    if transforms != out_transforms:
      raise NotImplementedError(
          f"Conflicting transforms for {op_user} in {op}: "
          f"{transforms} != {out_transforms}."
      )

  if transforms is None:
    return None
  return [], [transforms]
197+
198+
172199
def _should_have_transforms(op: ir.OpView) -> bool:
173200
"""Returns 'True' if the operation should be assigned in/out transforms."""
174201
return any(

tests/mosaic/gpu_transform_inference_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,78 @@ def body(smem_ref, value_to_store):
346346
with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
347347
mgpu.infer_transforms(self.module)
348348

349+
def test_infer_transforms_for_slice_smem_op_derives_from_user(self):
  """Checks that a `SliceSMEMOp` takes its out transforms from its only user.

  The slice result feeds a single `vector.LoadOp` annotated with the WGMMA
  layout. After inference the slice op must have no in-transforms and exactly
  one out-transforms entry: a (8, 64) tile followed by a 128 swizzle.
  """
  shape = (64, 64)
  elt_ty = ir.BF16Type.get()
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  i32 = ir.IntegerType.get_signless(32)
  slice_smem_op = vector_load_op = None

  def body(offset):
    nonlocal slice_smem_op, vector_load_op
    ref_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
    slice_smem_op = mgpu.dialect.SliceSMEMOp(ref_ty, offset)
    zero = arith.constant(i32, 0)
    vec_ty = ir.VectorType.get(shape, elt_ty)
    vector_load_op = vector.LoadOp(
        vec_ty, slice_smem_op.result, [zero for _ in shape]
    )

  with ir.InsertionPoint(self.module.body):
    func.FuncOp.from_py_func(i32)(body)

  wgmma_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)
  vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get([wgmma_layout])

  mgpu.infer_transforms(self.module)

  tile = mgpu.dialect.TileTransformAttr.get((8, 64))
  swizzle = mgpu.dialect.SwizzleTransformAttr.get(128)
  expected_transforms = ir.ArrayAttr.get([tile, swizzle])

  self.assertEmpty(inference_utils.in_transforms(slice_smem_op))
  self.assertSequenceEqual(
      inference_utils.out_transforms(slice_smem_op), [expected_transforms]
  )
384+
385+
def test_infer_transforms_for_slice_smem_op_raises_on_mismatches(self):
  """Checks that users with differing transforms make inference fail.

  Two `vector.LoadOp`s read the same `SliceSMEMOp` result. The first is
  annotated with the WGMMA layout; the second gets a strided fragmented layout
  and an explicit transpose in-transform. Inference is expected to raise
  `NotImplementedError` mentioning "Conflicting transforms".
  """
  shape = (64, 64)
  elt_ty = ir.BF16Type.get()
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  i32 = ir.IntegerType.get_signless(32)
  slice_smem_op = vector_load_op1 = vector_load_op2 = None

  def body(offset):
    nonlocal slice_smem_op, vector_load_op1, vector_load_op2
    slice_smem_op = mgpu.dialect.SliceSMEMOp(
        ir.MemRefType.get(shape, elt_ty, memory_space=smem), offset
    )
    zero = arith.constant(i32, 0)
    load_offsets = [zero for _ in shape]
    vec_ty = ir.VectorType.get(shape, elt_ty)
    vector_load_op1 = vector.LoadOp(vec_ty, slice_smem_op.result, load_offsets)
    vector_load_op2 = vector.LoadOp(vec_ty, slice_smem_op.result, load_offsets)

  with ir.InsertionPoint(self.module.body):
    func.FuncOp.from_py_func(i32)(body)

  wgmma_layout = layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)
  strided_layout = layouts_lib.to_layout_attr(
      fa.WGStridedFragLayout(shape, vec_size=4)
  )
  transpose = mgpu.dialect.TransposeTransformAttr.get((1, 0))

  vector_load_op1.attributes["out_layouts"] = ir.ArrayAttr.get([wgmma_layout])
  vector_load_op2.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout])
  vector_load_op2.attributes["in_transforms"] = ir.ArrayAttr.get(
      [ir.ArrayAttr.get([transpose])]
  )

  with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
    mgpu.infer_transforms(self.module)
420+
349421

350422
# Script entry point: when the file is executed directly, run its tests via
# `absltest.main`, passing `jtu.JaxTestLoader()` as the test loader.
if __name__ == "__main__":
351423
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)