Skip to content

Commit ced2816

Browse files
apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Use explicit recursion in rules instead of doing it automatically
Control-flow ops that have vector inputs or outputs will need to be specially adjusted.

PiperOrigin-RevId: 730922072
1 parent eb912ad commit ced2816

File tree

1 file changed

+68
-38
lines changed

1 file changed

+68
-38
lines changed

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 68 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@
2525
from jax._src.lib.mlir import ir
2626
from jax._src.lib.mlir.dialects import arith
2727
from jax._src.lib.mlir.dialects import builtin
28+
from jax._src.lib.mlir.dialects import func
2829
from jax._src.lib.mlir.dialects import gpu
2930
from jax._src.lib.mlir.dialects import llvm
3031
from jax._src.lib.mlir.dialects import nvvm
32+
from jax._src.lib.mlir.dialects import scf
3133
from jax._src.lib.mlir.dialects import vector
3234
import numpy as np
3335

@@ -45,10 +47,37 @@ class LoweringContext:
4547
launch_context: launch_context.LaunchContext | None
4648
single_thread_per_block_predicate: ir.Value | None
4749
single_thread_per_warpgroup_predicate: ir.Value | None
50+
lowered_operations: set[ir.Operation | ir.OpView] = dataclasses.field(
51+
default_factory=set
52+
)
53+
54+
def lower_op(self, op: ir.OpView):
55+
if not _should_lower(op):
56+
return
57+
58+
if (name := op.OPERATION_NAME) not in _lowerings:
59+
raise NotImplementedError(f"Missing lowering rule for {op}")
60+
61+
lowering_rule = _lowerings[name]
62+
63+
# TODO(bchetioui): make sure all layouts are set here.
64+
if layouts.should_have_layout(op) and not layouts.has_any_layout_set(op):
65+
raise ValueError(f"{op} is missing a layout and can not be lowered.")
66+
67+
new_results = lowering_rule(self, op)
68+
if new_results is not RECURSED:
69+
for old, new in zip(op.results, new_results):
70+
old.replace_all_uses_with(new)
71+
self.lowered_operations.add(op)
4872

4973

74+
class Recursed:
75+
pass
76+
RECURSED = Recursed()
77+
78+
MlirLoweringRuleResult = Sequence[ir.Value] | Recursed
5079
MlirLoweringRule = Callable[
51-
[LoweringContext, ir.Operation | ir.OpView], Sequence[ir.Value]
80+
[LoweringContext, ir.Operation | ir.OpView], MlirLoweringRuleResult
5281
]
5382

5483

@@ -544,6 +573,37 @@ def _mgpu_wait_op_lowering_rule(
544573
return []
545574

546575

576+
@_register_lowering(WaitOp)
577+
def _for_op_lowering_rule(
578+
_: LoweringContext, wait_op: scf.ForOp
579+
) -> Sequence[ir.Value]:
580+
581+
barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier)
582+
barrier.wait_parity(wait_op.parity)
583+
584+
return []
585+
586+
587+
@_register_lowering(func.FuncOp)
588+
@_register_lowering(gpu.LaunchOp)
589+
@_register_lowering(scf.IfOp) # TODO(apaszke,bchetioui): Add a proper rule.
590+
@_register_lowering(scf.ForOp) # TODO(apaszke,bchetioui): Add a proper rule.
591+
@_register_lowering(scf.IndexSwitchOp) # TODO(apaszke,bchetioui): Add a proper rule.
592+
def _traverse_op_lowering_rule(
593+
ctx: LoweringContext, op: ir.OpView
594+
) -> MlirLoweringRuleResult:
595+
if layouts.should_have_layout(op):
596+
raise ValueError(
597+
f"Rule cannot handle an op with vector operands or results: {op}"
598+
)
599+
for region in op.operation.regions:
600+
for block in region:
601+
for block_op in list(block):
602+
with ir.InsertionPoint(block_op):
603+
ctx.lower_op(block_op)
604+
return RECURSED
605+
606+
547607
def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
548608
"""Returns a single thread predicate per block and one per warpgroup."""
549609
block_predicate = warpgroup_predicate = None
@@ -572,11 +632,11 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
572632

573633
def _should_lower(op: ir.OpView) -> bool:
574634
"""Returns 'true' if the operation should be lowered."""
575-
if isinstance(op.name, ir.StringAttr):
576-
name = op.name.value
577-
else:
578-
name = op.name
579-
return name.startswith("mosaic_gpu.") or layouts.should_have_layout(op)
635+
return (
636+
op.OPERATION_NAME.startswith("mosaic_gpu.")
637+
or layouts.should_have_layout(op)
638+
or any(bool(b) for r in op.regions for b in r) # Does it have subblocks?
639+
)
580640

581641

582642
def lower_mgpu_dialect(
@@ -591,46 +651,16 @@ def lower_mgpu_dialect(
591651
module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
592652
module.context.load_all_available_dialects()
593653

594-
lowered_operations: set[ir.Operation | ir.OpView] = set()
595-
596654
# TODO(bchetioui): fix tests to not have a test-only path polluting the API.
597655
if launch_context is None: # this case is used in some tests
598656
block_predicate = warpgroup_predicate = None
599657
else:
600658
block_predicate, warpgroup_predicate = single_thread_predicates(module)
601659

602660
ctx = LoweringContext(launch_context, block_predicate, warpgroup_predicate)
603-
604-
def _lower_op(op: ir.OpView):
605-
if not _should_lower(op):
606-
return
607-
608-
if op.name not in _lowerings:
609-
raise NotImplementedError(f"Missing lowering rule for {op.name}")
610-
611-
lowering_rule = _lowerings[op.name]
612-
613-
# TODO(bchetioui): make sure all layouts are set here.
614-
if layouts.should_have_layout(op) and not layouts.has_any_layout_set(op):
615-
raise ValueError(f"{op} is missing a layout and can not be lowered.")
616-
617-
new_results = lowering_rule(ctx, op)
618-
619-
for old, new in zip(op.results, new_results):
620-
old.replace_all_uses_with(new)
621-
lowered_operations.add(op)
622-
623-
def _traverse_and_lower_op(op: ir.OpView):
624-
for region in op.operation.regions:
625-
for block in region:
626-
for block_op in list(block):
627-
with ir.InsertionPoint(block_op):
628-
_traverse_and_lower_op(block_op)
629-
_lower_op(op)
630-
631661
with ir.InsertionPoint(module.body):
632662
for op in list(module.body):
633-
_traverse_and_lower_op(op)
663+
ctx.lower_op(op)
634664

635-
for lowered_op in lowered_operations:
665+
for lowered_op in ctx.lowered_operations:
636666
lowered_op.erase()

0 commit comments

Comments (0)