Skip to content

Commit 1f1d27d

Browse files
bchetioui authored and Google-ML-Automation committed
[Mosaic GPU] Implement the skeleton of a lowering pass for the Mosaic GPU dialect.
Also add a lowering rule for `mosaic_gpu.initialize_barrier`. PiperOrigin-RevId: 694276698
1 parent 0bb30f0 commit 1f1d27d

File tree

3 files changed

+216
-1
lines changed

3 files changed

+216
-1
lines changed

jax/experimental/mosaic/gpu/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
Union as Union,
2828
as_gpu_kernel as as_gpu_kernel,
2929
)
30+
31+
# The lowering pass is only imported when the dialect bindings are present;
# otherwise `lower_mgpu_dialect` is exposed as `None`.
if dialect is None:
  lower_mgpu_dialect = None
else:
  from .dialect_lowering import lower_mgpu_dialect
35+
3036
from .fragmented_array import (
3137
FragmentedArray as FragmentedArray,
3238
FragmentedLayout as FragmentedLayout,
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Lowering rules and pass for the MLIR Mosaic GPU dialect."""
16+
17+
from collections.abc import Callable
18+
import functools
19+
import operator
20+
from typing import Sequence, Type
21+
22+
from jax._src.interpreters import mlir as mlir_interpreter
23+
from jax._src.lib import mosaic_gpu_dialect as mgpu
24+
25+
from jaxlib.mlir import ir
26+
from jaxlib.mlir.dialects import gpu
27+
from jaxlib.mlir.dialects import llvm
28+
from jaxlib.mlir.dialects import memref
29+
from jaxlib.mlir.dialects import nvvm
30+
from .utils import c, memref_ptr, single_thread_predicate
31+
32+
33+
# A lowering rule takes the operation to lower and returns a sequence of
# values.
MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]]


# Registry of lowering rules, keyed by operation name.
_lowerings: dict[str, MlirLoweringRule] = {}


# TODO(bchetioui): Remove this when minimum jaxlib version >= 0.4.36.
# Older jaxlib versions may not contain the Mosaic GPU dialect bindings, in
# which case `mgpu` is `None`.
if mgpu is None:
  InitializeBarrierOp = None
else:
  InitializeBarrierOp = mgpu.InitializeBarrierOp
42+
43+
def _register_lowering(
    op: str | Type[ir.OpView]
) -> Callable[[MlirLoweringRule], MlirLoweringRule]:
  """Returns a decorator that registers a lowering rule for `op`.

  Args:
    op: either an operation name, or an op class whose `OPERATION_NAME`
      attribute is used as the operation name.

  Returns:
    A decorator that stores the decorated function in `_lowerings` under the
    operation name, and returns the function unchanged.
  """
  def decorator(rule):
    # The name is resolved when the decorator is applied, not when it is built.
    if isinstance(op, str):
      op_name = op
    else:
      op_name = op.OPERATION_NAME  # pytype: disable=attribute-error
    _lowerings[op_name] = rule
    return rule

  return decorator
52+
53+
54+
def _lowered_barrier_type() -> ir.Type:
  """Returns the type barriers are lowered to: a signless 64-bit integer."""
  i64 = ir.IntegerType.get_signless(64)
  return i64
56+
57+
58+
def _gpu_address_space_to_nvptx(address_space: gpu.AddressSpace) -> int:
  """Maps a `gpu.AddressSpace` to the matching integer address space.

  Args:
    address_space: the GPU dialect address space to convert.

  Returns:
    1 for `gpu.AddressSpace.Global` and 3 for `gpu.AddressSpace.Workgroup`.

  Raises:
    NotImplementedError: if `address_space` is any other value.
  """
  if address_space == gpu.AddressSpace.Global:
    return 1
  if address_space == gpu.AddressSpace.Workgroup:
    return 3
  raise NotImplementedError(f"address_space not supported: {address_space}")
66+
67+
68+
@_register_lowering(InitializeBarrierOp)
def _initialize_barrier_op_lowering_rule(
    initialize_barrier_op: InitializeBarrierOp) -> Sequence[ir.Value]:
  """Lowers `mosaic_gpu.initialize_barrier`.

  Allocates (with `memref.alloca`) a memref with the same shape as the op's
  `barriers_ref` and the lowered barrier element type, takes a pointer to it,
  and emits one `nvvm.mbarrier_init_shared` per barrier. Each init receives
  the op's arrival count as an i32 constant, and is predicated on
  `single_thread_predicate(per_block=True)`.

  Args:
    initialize_barrier_op: the operation to lower.

  Returns:
    A 1-tuple holding the pointer to the lowered barriers.
  """
  barriers_shape = initialize_barrier_op.barriers_ref.type.shape
  num_barriers = functools.reduce(operator.mul, barriers_shape, 1)

  i32 = ir.IntegerType.get_signless(32)
  workgroup_space = _gpu_address_space_to_nvptx(gpu.AddressSpace.Workgroup)
  ptr_ty = ir.Type.parse(f"!llvm.ptr<{workgroup_space}>")
  barrier_ty = _lowered_barrier_type()

  alloca_ty = ir.MemRefType.get(barriers_shape, barrier_ty)
  lowered_barrier_ref = memref.alloca(alloca_ty, [], [])
  barrier_ref_address = memref_ptr(
      lowered_barrier_ref, memory_space=workgroup_space)

  predicate = single_thread_predicate(per_block=True)
  for barrier_index in range(num_barriers):
    # The element pointer is built before the constant, so ops are emitted in
    # the order: getelementptr, constant, mbarrier_init_shared.
    barrier_ptr = llvm.getelementptr(
        ptr_ty, barrier_ref_address, [], [barrier_index], barrier_ty)
    arrival_count = c(initialize_barrier_op.arrival_count.value, i32)
    nvvm.mbarrier_init_shared(barrier_ptr, arrival_count, predicate=predicate)
  return (barrier_ref_address,)
95+
96+
97+
def lower_mgpu_dialect(module: ir.Module):
  """Lowers, in place, every op in `module` that has a registered lowering.

  The module is traversed in post-order (nested ops before the op that
  contains them). For each op whose name is a key of `_lowerings`, the
  lowering rule is called, uses of the op's results are replaced with the
  values the rule returns, and the op is erased once the traversal is done.

  Args:
    module: the module to lower. It is mutated in place.
  """
  module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
  module.context.load_all_available_dialects()

  # A list (rather than a set) keeps the erasure order deterministic: ops are
  # recorded in post-order, so nested ops are always erased before the op that
  # contains them.
  lowered_operations: list[ir.Operation | ir.OpView] = []

  def _lower_op(op: ir.OpView):
    lowering_rule = _lowerings.get(op.name)
    if lowering_rule is None:
      return
    new_results = lowering_rule(op)
    for old, new in zip(op.results, new_results):
      old.replace_all_uses_with(new)
    lowered_operations.append(op)

  def _traverse_and_lower_op(op: ir.OpView):
    for region in op.operation.regions:
      for block in region:
        # Snapshot the block: lowering rules insert new ops while we iterate.
        for block_op in list(block):
          with ir.InsertionPoint(block_op):
            _traverse_and_lower_op(block_op)
    _lower_op(op)

  with ir.InsertionPoint(module.body):
    # Snapshot the module body as well: rules for top-level ops insert new ops
    # into the very block being iterated.
    for op in list(module.body):
      _traverse_and_lower_op(op)

  for lowered_op in lowered_operations:
    lowered_op.erase()

tests/mosaic/gpu_dialect_test.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,19 @@
1414
# ==============================================================================
1515
"""(Deviceless) tests for the Mosaic GPU MLIR dialect."""
1616

17+
from typing import Callable
18+
1719
from absl.testing import parameterized
1820
from jax._src import config
1921
from jax._src import test_util as jtu
22+
from jax._src.interpreters import mlir as mlir_interpreter
2023
from jax._src.lib.mlir import ir
21-
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
24+
from jax._src.lib.mlir.dialects import arith
25+
from jax._src.lib.mlir.dialects import nvvm
26+
from jax._src.lib.mlir.dialects import scf
2227

28+
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
29+
from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import
2330

2431
_cext = mgpu._cext if mgpu is not None else None
2532

@@ -29,10 +36,35 @@
2936

3037
def _make_ir_context():
  """Creates an `ir.Context` set up for the tests in this file.

  The upstream dialect registry is appended and all available dialects are
  loaded before the Mosaic GPU dialect is registered.
  """
  ctx = ir.Context()
  ctx.append_dialect_registry(mlir_interpreter.upstream_dialects)
  ctx.load_all_available_dialects()
  mgpu.register_dialect(ctx)
  return ctx
3443

3544

45+
def walk_operations(op: ir.OpView, callback):
  """Calls `callback` on `op` and on every op nested in it, in post-order.

  Args:
    op: the root operation of the walk.
    callback: called with each visited operation; nested operations are
      visited before the operation that contains them.
  """
  for region in op.operation.regions:
    for block in region:
      for nested_op in block:
        walk_operations(nested_op, callback)
  # The root is visited last, after everything nested in it.
  callback(op)
51+
52+
53+
def find_if(module: ir.Module,
            predicate: Callable[[ir.OpView], bool]) -> list[ir.OpView]:
  """Returns all ops in `module`, nested ones included, matching `predicate`.

  Ops are collected in the order `walk_operations` visits them.
  """
  matches = []

  def collect(op: ir.OpView):
    if predicate(op):
      matches.append(op)

  for top_level_op in module.body.operations:
    walk_operations(top_level_op, collect)
  return matches
62+
63+
64+
def is_mosaic_gpu_op(op: ir.OpView) -> bool:
  """Returns whether the name of `op` starts with the `mosaic_gpu.` prefix."""
  op_name = op.name
  return op_name.startswith("mosaic_gpu.")
66+
67+
3668
class DialectTest(parameterized.TestCase):
3769

3870
def setUp(self):
@@ -72,5 +104,57 @@ def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self):
72104
mgpu.InitializeBarrierOp)
73105

74106

107+
class DialectLoweringTest(DialectTest):
  """Tests for `lower_mgpu_dialect`."""

  def test_lowering_removes_mosaic_gpu_ops(self):
    with ir.InsertionPoint(self.module.body):
      barrier_ty = ir.Type.parse("!mosaic_gpu.barrier")
      mgpu.initialize_barrier(
          ir.MemRefType.get((1, 2), barrier_ty), arrival_count=1)
    lower_mgpu_dialect(self.module)

    # No top-level op of the module may belong to the Mosaic GPU dialect.
    remaining = [
        op for op in self.module.body.operations if is_mosaic_gpu_op(op)
    ]
    self.assertEmpty(remaining)

  def test_lowering_traverses_regions_correctly(self):
    with ir.InsertionPoint(self.module.body):
      i1 = ir.IntegerType.get_signless(1)
      cst_true = arith.constant(i1, ir.IntegerAttr.get(i1, 1))
      if_op = scf.IfOp(cst_true)
      with ir.InsertionPoint(if_op.then_block):
        barrier_ty = ir.Type.parse("!mosaic_gpu.barrier")
        mgpu.initialize_barrier(
            ir.MemRefType.get((1, 2), barrier_ty), arrival_count=1)
        scf.yield_([])
    lower_mgpu_dialect(self.module)

    # The op nested in the `then` block must have been lowered as well.
    remaining = [
        op for op in if_op.then_block.operations if is_mosaic_gpu_op(op)
    ]
    self.assertEmpty(remaining)

  def test_initialize_barrier_op_lowering_rule(self):
    shape = (3, 4)
    num_shape_elements = shape[0] * shape[1]
    arrival_count = 1337

    with ir.InsertionPoint(self.module.body):
      barrier_ty = ir.Type.parse("!mosaic_gpu.barrier")
      mgpu.initialize_barrier(
          ir.MemRefType.get(shape, barrier_ty), arrival_count=arrival_count)
    lower_mgpu_dialect(self.module)

    def is_mbarrier_init_shared(op):
      return op.name == nvvm.MBarrierInitSharedOp.OPERATION_NAME

    init_ops = find_if(self.module, is_mbarrier_init_shared)

    # One nvvm.mbarrier_init_shared is issued per barrier.
    self.assertLen(init_ops, num_shape_elements)

    # Each barrier has its count equal to the arrival count.
    for init_op in init_ops:
      count_op = init_op.count.owner.opview
      self.assertIsInstance(count_op, arith.ConstantOp)
      self.assertEqual(count_op.literal_value, arrival_count)
157+
158+
75159
if __name__ == "__main__":
76160
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)