 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
 from jax._src.lib.mlir.dialects import builtin
+from jax._src.lib.mlir.dialects import func
 from jax._src.lib.mlir.dialects import gpu
 from jax._src.lib.mlir.dialects import llvm
 from jax._src.lib.mlir.dialects import nvvm
+from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
 import numpy as np
 
@@ -45,10 +47,37 @@ class LoweringContext:
   launch_context: launch_context.LaunchContext | None
   single_thread_per_block_predicate: ir.Value | None
   single_thread_per_warpgroup_predicate: ir.Value | None
+  lowered_operations: set[ir.Operation | ir.OpView] = dataclasses.field(
+      default_factory=set
+  )
+
+  def lower_op(self, op: ir.OpView):
+    if not _should_lower(op):
+      return
+
+    if (name := op.OPERATION_NAME) not in _lowerings:
+      raise NotImplementedError(f"Missing lowering rule for {op}")
+
+    lowering_rule = _lowerings[name]
+
+    # TODO(bchetioui): make sure all layouts are set here.
+    if layouts.should_have_layout(op) and not layouts.has_any_layout_set(op):
+      raise ValueError(f"{op} is missing a layout and can not be lowered.")
+
+    new_results = lowering_rule(self, op)
+    if new_results is not RECURSED:
+      for old, new in zip(op.results, new_results):
+        old.replace_all_uses_with(new)
+      self.lowered_operations.add(op)
 
 
+class Recursed:
+  pass
+RECURSED = Recursed()
+
+MlirLoweringRuleResult = Sequence[ir.Value] | Recursed
 MlirLoweringRule = Callable[
-    [LoweringContext, ir.Operation | ir.OpView], Sequence[ir.Value]
+    [LoweringContext, ir.Operation | ir.OpView], MlirLoweringRuleResult
 ]
 
 
@@ -544,6 +573,37 @@ def _mgpu_wait_op_lowering_rule(
   return []
 
 
+@_register_lowering(WaitOp)
+def _mgpu_wait_op_lowering_rule(
+    _: LoweringContext, wait_op: WaitOp
+) -> Sequence[ir.Value]:
+
+  barrier = utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier)
+  barrier.wait_parity(wait_op.parity)
+
+  return []
+
+
+@_register_lowering(func.FuncOp)
+@_register_lowering(gpu.LaunchOp)
+@_register_lowering(scf.IfOp)  # TODO(apaszke,bchetioui): Add a proper rule.
+@_register_lowering(scf.ForOp)  # TODO(apaszke,bchetioui): Add a proper rule.
+@_register_lowering(scf.IndexSwitchOp)  # TODO(apaszke,bchetioui): Add a proper rule.
+def _traverse_op_lowering_rule(
+    ctx: LoweringContext, op: ir.OpView
+) -> MlirLoweringRuleResult:
+  if layouts.should_have_layout(op):
+    raise ValueError(
+        f"Rule cannot handle an op with vector operands or results: {op}"
+    )
+  for region in op.operation.regions:
+    for block in region:
+      for block_op in list(block):
+        with ir.InsertionPoint(block_op):
+          ctx.lower_op(block_op)
+  return RECURSED
+
+
 def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
   """Returns a single thread predicate per block and one per warpgroup."""
   block_predicate = warpgroup_predicate = None
@@ -572,11 +632,11 @@ def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
 
 def _should_lower(op: ir.OpView) -> bool:
   """Returns 'true' if the operation should be lowered."""
-  if isinstance(op.name, ir.StringAttr):
-    name = op.name.value
-  else:
-    name = op.name
-  return name.startswith("mosaic_gpu.") or layouts.should_have_layout(op)
+  return (
+      op.OPERATION_NAME.startswith("mosaic_gpu.")
+      or layouts.should_have_layout(op)
+      or any(bool(b) for r in op.regions for b in r)  # Does it have subblocks?
+  )
 
 
 def lower_mgpu_dialect(
@@ -591,46 +651,16 @@ def lower_mgpu_dialect(
   module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
   module.context.load_all_available_dialects()
 
-  lowered_operations: set[ir.Operation | ir.OpView] = set()
-
   # TODO(bchetioui): fix tests to not have a test-only path polluting the API.
   if launch_context is None:  # this case is used in some tests
     block_predicate = warpgroup_predicate = None
   else:
     block_predicate, warpgroup_predicate = single_thread_predicates(module)
 
   ctx = LoweringContext(launch_context, block_predicate, warpgroup_predicate)
-
-  def _lower_op(op: ir.OpView):
-    if not _should_lower(op):
-      return
-
-    if op.name not in _lowerings:
-      raise NotImplementedError(f"Missing lowering rule for {op.name}")
-
-    lowering_rule = _lowerings[op.name]
-
-    # TODO(bchetioui): make sure all layouts are set here.
-    if layouts.should_have_layout(op) and not layouts.has_any_layout_set(op):
-      raise ValueError(f"{op} is missing a layout and can not be lowered.")
-
-    new_results = lowering_rule(ctx, op)
-
-    for old, new in zip(op.results, new_results):
-      old.replace_all_uses_with(new)
-    lowered_operations.add(op)
-
-  def _traverse_and_lower_op(op: ir.OpView):
-    for region in op.operation.regions:
-      for block in region:
-        for block_op in list(block):
-          with ir.InsertionPoint(block_op):
-            _traverse_and_lower_op(block_op)
-    _lower_op(op)
-
   with ir.InsertionPoint(module.body):
     for op in list(module.body):
-      _traverse_and_lower_op(op)
+      ctx.lower_op(op)
 
-  for lowered_op in lowered_operations:
+  for lowered_op in ctx.lowered_operations:
     lowered_op.erase()
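
Note for readers skimming the diff: under the new contract a lowering rule either returns one replacement ir.Value per result of the op it lowers (lower_op then rewires the uses and schedules the op for erasure) or returns the RECURSED sentinel, meaning the rule only descended into the op's regions via ctx.lower_op and the op itself stays in place. The sketch below is a self-contained, simplified illustration of that dispatch, not code from this change; the names rewrite_rule, traverse_rule, lower and the op strings are invented for the example.

# Illustration only: mirrors the RECURSED-sentinel contract used by
# LoweringContext.lower_op in this change. All names here are invented.
from collections.abc import Callable, Sequence


class Recursed:
  pass


RECURSED = Recursed()

RuleResult = Sequence[str] | Recursed
Rule = Callable[[str], RuleResult]


def rewrite_rule(op: str) -> RuleResult:
  # A rewriting rule: returns one replacement per result of `op`.
  return [f"lowered({op})"]


def traverse_rule(op: str) -> RuleResult:
  # A traversal rule: it only recursed into nested ops, so the op itself
  # is kept and there is nothing to splice in.
  return RECURSED


def lower(rule: Rule, op: str, to_erase: list[str]) -> list[str]:
  results = rule(op)
  if results is not RECURSED:
    to_erase.append(op)   # rewritten ops are erased once lowering is done
    return list(results)  # replacements for the op's results
  return []               # RECURSED: keep the op, nothing to replace


erase_list: list[str] = []
assert lower(rewrite_rule, "mosaic_gpu.wait", erase_list) == ["lowered(mosaic_gpu.wait)"]
assert lower(traverse_rule, "scf.for", erase_list) == []
assert erase_list == ["mosaic_gpu.wait"]

Keeping the sentinel distinct from an empty result list is the point of the design: an op whose rule recursed must not be erased (erasing a traversed func.FuncOp or gpu.LaunchOp would delete the surrounding function or kernel), while an op that was genuinely rewritten must be.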