Skip to content

Commit 8a7bf2e

Browse files
bchetiouiGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Ensure that lowering InitializeBarrierOp preserves the result's type.
Otherwise, the lowered IR won't be type-correct. PiperOrigin-RevId: 695339726
1 parent 1d24630 commit 8a7bf2e

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from jaxlib.mlir.dialects import gpu
2727
from jaxlib.mlir.dialects import llvm
2828
from jaxlib.mlir.dialects import nvvm
29-
from .utils import c, single_thread_predicate
29+
from .utils import c, ptr_as_memref, single_thread_predicate
3030

3131
# mypy: ignore-errors
3232

@@ -89,7 +89,12 @@ def _initialize_barrier_op_lowering_rule(
8989
predicate=predicate
9090
)
9191

92-
return initialize_barrier_op.base_pointer,
92+
barrier_base_ptr = llvm.getelementptr(
93+
ir.Type.parse("!llvm.ptr"),
94+
initialize_barrier_op.base_pointer, [], [0], lowered_barrier_type)
95+
96+
return ptr_as_memref(
97+
barrier_base_ptr, initialize_barrier_op.barriers_ref.type),
9398

9499

95100
def lower_mgpu_dialect(module: ir.Module):

tests/mosaic/gpu_dialect_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from jax._src.lib.mlir.dialects import func
2626
from jax._src.lib.mlir.dialects import gpu
2727
from jax._src.lib.mlir.dialects import llvm
28+
from jax._src.lib.mlir.dialects import memref
2829
from jax._src.lib.mlir.dialects import nvvm
2930
from jax._src.lib.mlir.dialects import scf
3031
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
@@ -512,11 +513,17 @@ def test_initialize_barrier_op_lowering_rule(self):
512513
arrival_count = 1337
513514

514515
with ir.InsertionPoint(self.module.body):
515-
mgpu.initialize_barrier(
516+
barriers_ref = mgpu.initialize_barrier(
516517
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
517518
llvm.UndefOp(workgroup_ptr_ty()),
518519
arrival_count=arrival_count)
520+
# Add a user for barriers_ref to make sure that the lowering keeps types
521+
# consistent.
522+
memref.copy(barriers_ref, barriers_ref)
523+
524+
self.assertTrue(self.module.operation.verify())
519525
lower_mgpu_dialect(self.module)
526+
self.assertTrue(self.module.operation.verify())
520527

521528
all_mbarrier_init_shared_ops = find_if(
522529
self.module,

0 commit comments

Comments
 (0)