Skip to content

Commit 4ef7706

Browse files
bchetioui and Google-ML-Automation
authored and committed
[Mosaic GPU] Split layout inference and dialect lowering files and tests.
PiperOrigin-RevId: 705100503
1 parent 354bd52 commit 4ef7706

File tree

6 files changed

+321
-241
lines changed

6 files changed

+321
-241
lines changed

jax/experimental/mosaic/gpu/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@
3131
if dialect is not None:
3232
from .dialect_lowering import (
3333
gpu_address_space_to_nvptx as gpu_address_space_to_nvptx,
34-
infer_layout,
3534
lower_mgpu_dialect as lower_mgpu_dialect,
36-
splat_fragmented_layout,
37-
strided_fragmented_layout,
35+
)
36+
from .layout_inference import (
37+
infer_layout as infer_layout,
38+
splat_fragmented_layout as splat_fragmented_layout,
39+
strided_fragmented_layout as strided_fragmented_layout,
3840
)
3941
else:
4042
gpu_address_space_to_nvptx = None

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 1 addition & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515
"""Lowering rules and pass for the MLIR Mosaic GPU dialect."""
1616

1717
from collections.abc import Callable
18-
import enum
1918
import functools
20-
import itertools
2119
import operator
22-
from typing import List, Sequence, Tuple, Type, cast
20+
from typing import Sequence, Type
2321

2422
from jax._src.interpreters import mlir as mlir_interpreter
2523
from jax._src.lib import mosaic_gpu_dialect as mgpu
2624
from jax._src.lib.mlir import ir
27-
from jax._src.lib.mlir.dialects import arith
2825
from jax._src.lib.mlir.dialects import gpu
2926
from jax._src.lib.mlir.dialects import llvm
3027
from jax._src.lib.mlir.dialects import nvvm
@@ -34,169 +31,6 @@
3431
# mypy: ignore-errors
3532

3633

37-
def strided_fragmented_layout():
38-
layout = mgpu.FragmentedLayout.WGStridedFragLayout
39-
return ir.Attribute.parse(f"#mosaic_gpu.fragmented_layout<{layout}>")
40-
41-
42-
def splat_fragmented_layout():
43-
layout = mgpu.FragmentedLayout.WGSplatFragLayout
44-
return ir.Attribute.parse(f"#mosaic_gpu.fragmented_layout<{layout}>")
45-
46-
47-
_layout_inference_rules: dict[
48-
str,
49-
Callable[[ir.OpView], Tuple[List[ir.Attribute], List[ir.Attribute]] | None],
50-
] = {}
51-
52-
53-
def _add_layout_inference_rule(
54-
op: Type[ir.OpView],
55-
rule: Callable[
56-
[ir.OpView], Tuple[List[ir.Attribute], List[ir.Attribute]] | None
57-
],
58-
):
59-
_layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error
60-
61-
62-
def _set_layout_attributes(
63-
op: ir.OpView,
64-
in_layouts: List[ir.Attribute],
65-
out_layouts: List[ir.Attribute],
66-
):
67-
op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts)
68-
op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts)
69-
70-
71-
def _extract_any_layout_from_op(op: ir.OpView) -> ir.Attribute | None:
72-
if "in_layouts" in op.attributes and len(op.operands) > 0:
73-
return cast(ir.ArrayAttr, op.attributes["in_layouts"])[0]
74-
elif "out_layouts" in op.attributes and len(op.results) > 0:
75-
return cast(ir.ArrayAttr, op.attributes["out_layouts"])[0]
76-
77-
return None
78-
79-
80-
def _infer_pointwise_op_layouts(
81-
op: ir.OpView,
82-
) -> Tuple[List[ir.Attribute], List[ir.Attribute]] | None:
83-
layout = _extract_any_layout_from_op(op)
84-
# The op had no layout set. Since we're annotating ops, we may need to
85-
# derive layout information from user or producer ops.
86-
if layout is None:
87-
# First, we iterate on users.
88-
for op_result in op.results:
89-
for op_user in cast(ir.OpResult, op_result).uses:
90-
layout = _extract_any_layout_from_op(op_user.owner)
91-
if layout:
92-
break
93-
else:
94-
continue
95-
break
96-
97-
if layout is None:
98-
# Still no layout set. We iterate on producers.
99-
for operand in op.operands:
100-
layout = _extract_any_layout_from_op(operand.owner)
101-
if layout:
102-
break
103-
104-
if layout is None:
105-
return None
106-
107-
return ([layout for _ in op.operands], [layout for _ in op.results])
108-
109-
110-
for op in (
111-
arith.AddFOp,
112-
arith.ConstantOp,
113-
arith.MulFOp,
114-
):
115-
_add_layout_inference_rule(op, _infer_pointwise_op_layouts)
116-
117-
118-
def _layout_inference_should_process_op(op: ir.OpView) -> bool:
119-
"""Returns 'true' if the layout inference pass can skip the operation."""
120-
121-
def is_array(v: ir.Value):
122-
ty = v.type
123-
return ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty)
124-
125-
return any(map(is_array, itertools.chain(op.operands, op.results)))
126-
127-
128-
def _has_any_layout_set(op: ir.OpView) -> bool:
129-
return "in_layouts" in op.attributes or "out_layouts" in op.attributes
130-
131-
132-
class TraversalOrder(enum.Enum):
133-
"""Traversal orders with respect to the data flow for IR."""
134-
135-
FORWARD = 1
136-
BACKWARDS = 2
137-
138-
139-
def traverse_op(
140-
op: ir.OpView,
141-
callback: Callable[[ir.OpView], None],
142-
traversal_order: TraversalOrder = TraversalOrder.FORWARD,
143-
):
144-
"""Traverses the operation and applies the callback in the given order."""
145-
for region in op.operation.regions:
146-
for block in region:
147-
if traversal_order == TraversalOrder.FORWARD:
148-
ops_to_traverse = block
149-
else:
150-
ops_to_traverse = reversed(list(block))
151-
for block_op in ops_to_traverse:
152-
callback(block_op)
153-
callback(op)
154-
155-
156-
def infer_layout(module: ir.Module):
157-
def inference_step(op: ir.Operation):
158-
if not _layout_inference_should_process_op(op):
159-
return
160-
elif inference_rule := _layout_inference_rules.get(op.OPERATION_NAME, None): # pytype: disable=attribute-error
161-
pass
162-
else:
163-
raise NotImplementedError(f"Can not infer layout for {op}")
164-
165-
maybe_layouts = inference_rule(op)
166-
if maybe_layouts is None:
167-
return
168-
169-
_set_layout_attributes(op, *maybe_layouts)
170-
171-
# We run two passes over the module, in order to make sure that layouts
172-
# defined in the middle of the computation are propagated wherever they need
173-
# to be propagated. We start with a backwards (root-to-parameters) pass to
174-
# propagate the information as far up as possible, and then a forward pass
175-
# (parameters-to-root).
176-
#
177-
# Backwards pass
178-
for op in module.body:
179-
traverse_op(op, inference_step, TraversalOrder.BACKWARDS)
180-
181-
# Forward pass
182-
for op in module.body:
183-
traverse_op(op, inference_step, TraversalOrder.FORWARD)
184-
185-
# At this point, layouts have been propagated as far as they could be
186-
# propagated. However, it is possible for some operations to remain
187-
# unannotated---for example, if there were no annotations on any operation in
188-
# the module at the start of this function. We annotate all the remaining ops
189-
# that should be annotated with a strided fragmented layout.
190-
def set_default_layout(op: ir.OpView):
191-
layout = strided_fragmented_layout()
192-
if _layout_inference_should_process_op(op) and not _has_any_layout_set(op):
193-
_set_layout_attributes(
194-
op, [layout] * len(op.operands), [layout] * len(op.results))
195-
196-
for op in module.body:
197-
traverse_op(op, set_default_layout)
198-
199-
20034
MlirLoweringRule = Callable[[ir.Operation | ir.OpView], Sequence[ir.Value]]
20135

20236

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Layout inference pass for the MLIR Mosaic GPU dialect."""
16+
17+
from collections.abc import Callable
18+
import enum
19+
import itertools
20+
from typing import List, Tuple, Type, cast
21+
22+
from jax._src.lib import mosaic_gpu_dialect as mgpu
23+
from jax._src.lib.mlir import ir
24+
from jax._src.lib.mlir.dialects import arith
25+
26+
# mypy: ignore-errors
27+
28+
29+
def strided_fragmented_layout():
30+
layout = mgpu.FragmentedLayout.WGStridedFragLayout
31+
return ir.Attribute.parse(f"#mosaic_gpu.fragmented_layout<{layout}>")
32+
33+
34+
def splat_fragmented_layout():
35+
layout = mgpu.FragmentedLayout.WGSplatFragLayout
36+
return ir.Attribute.parse(f"#mosaic_gpu.fragmented_layout<{layout}>")
37+
38+
39+
_layout_inference_rules: dict[
40+
str,
41+
Callable[[ir.OpView], Tuple[List[ir.Attribute], List[ir.Attribute]] | None],
42+
] = {}
43+
44+
45+
def _add_layout_inference_rule(
46+
op: Type[ir.OpView],
47+
rule: Callable[
48+
[ir.OpView], Tuple[List[ir.Attribute], List[ir.Attribute]] | None
49+
],
50+
):
51+
_layout_inference_rules[op.OPERATION_NAME] = rule # pytype: disable=attribute-error
52+
53+
54+
def _set_layout_attributes(
55+
op: ir.OpView,
56+
in_layouts: List[ir.Attribute],
57+
out_layouts: List[ir.Attribute],
58+
):
59+
op.attributes["in_layouts"] = ir.ArrayAttr.get(in_layouts)
60+
op.attributes["out_layouts"] = ir.ArrayAttr.get(out_layouts)
61+
62+
63+
def _extract_any_layout_from_op(op: ir.OpView) -> ir.Attribute | None:
64+
if "in_layouts" in op.attributes and len(op.operands) > 0:
65+
return cast(ir.ArrayAttr, op.attributes["in_layouts"])[0]
66+
elif "out_layouts" in op.attributes and len(op.results) > 0:
67+
return cast(ir.ArrayAttr, op.attributes["out_layouts"])[0]
68+
69+
return None
70+
71+
72+
def _infer_pointwise_op_layouts(
73+
op: ir.OpView,
74+
) -> Tuple[List[ir.Attribute], List[ir.Attribute]] | None:
75+
layout = _extract_any_layout_from_op(op)
76+
# The op had no layout set. Since we're annotating ops, we may need to
77+
# derive layout information from user or producer ops.
78+
if layout is None:
79+
# First, we iterate on users.
80+
for op_result in op.results:
81+
for op_user in cast(ir.OpResult, op_result).uses:
82+
layout = _extract_any_layout_from_op(op_user.owner)
83+
if layout:
84+
break
85+
else:
86+
continue
87+
break
88+
89+
if layout is None:
90+
# Still no layout set. We iterate on producers.
91+
for operand in op.operands:
92+
layout = _extract_any_layout_from_op(operand.owner)
93+
if layout:
94+
break
95+
96+
if layout is None:
97+
return None
98+
99+
return ([layout for _ in op.operands], [layout for _ in op.results])
100+
101+
102+
for op in (
103+
arith.AddFOp,
104+
arith.ConstantOp,
105+
arith.MulFOp,
106+
):
107+
_add_layout_inference_rule(op, _infer_pointwise_op_layouts)
108+
109+
110+
def _layout_inference_should_process_op(op: ir.OpView) -> bool:
111+
"""Returns 'true' if the layout inference pass can skip the operation."""
112+
113+
def is_array(v: ir.Value):
114+
ty = v.type
115+
return ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty)
116+
117+
return any(map(is_array, itertools.chain(op.operands, op.results)))
118+
119+
120+
def _has_any_layout_set(op: ir.OpView) -> bool:
121+
return "in_layouts" in op.attributes or "out_layouts" in op.attributes
122+
123+
124+
class TraversalOrder(enum.Enum):
125+
"""Traversal orders with respect to the data flow for IR."""
126+
127+
FORWARD = 1
128+
BACKWARDS = 2
129+
130+
131+
def traverse_op(
132+
op: ir.OpView,
133+
callback: Callable[[ir.OpView], None],
134+
traversal_order: TraversalOrder = TraversalOrder.FORWARD,
135+
):
136+
"""Traverses the operation and applies the callback in the given order."""
137+
for region in op.operation.regions:
138+
for block in region:
139+
if traversal_order == TraversalOrder.FORWARD:
140+
ops_to_traverse = block
141+
else:
142+
ops_to_traverse = reversed(list(block))
143+
for block_op in ops_to_traverse:
144+
callback(block_op)
145+
callback(op)
146+
147+
148+
def infer_layout(module: ir.Module):
149+
def inference_step(op: ir.Operation):
150+
if not _layout_inference_should_process_op(op):
151+
return
152+
elif inference_rule := _layout_inference_rules.get(op.OPERATION_NAME, None): # pytype: disable=attribute-error
153+
pass
154+
else:
155+
raise NotImplementedError(f"Can not infer layout for {op}")
156+
157+
maybe_layouts = inference_rule(op)
158+
if maybe_layouts is None:
159+
return
160+
161+
_set_layout_attributes(op, *maybe_layouts)
162+
163+
# We run two passes over the module, in order to make sure that layouts
164+
# defined in the middle of the computation are propagated wherever they need
165+
# to be propagated. We start with a backwards (root-to-parameters) pass to
166+
# propagate the information as far up as possible, and then a forward pass
167+
# (parameters-to-root).
168+
#
169+
# Backwards pass
170+
for op in module.body:
171+
traverse_op(op, inference_step, TraversalOrder.BACKWARDS)
172+
173+
# Forward pass
174+
for op in module.body:
175+
traverse_op(op, inference_step, TraversalOrder.FORWARD)
176+
177+
# At this point, layouts have been propagated as far as they could be
178+
# propagated. However, it is possible for some operations to remain
179+
# unannotated---for example, if there were no annotations on any operation in
180+
# the module at the start of this function. We annotate all the remaining ops
181+
# that should be annotated with a strided fragmented layout.
182+
def set_default_layout(op: ir.OpView):
183+
layout = strided_fragmented_layout()
184+
if _layout_inference_should_process_op(op) and not _has_any_layout_set(op):
185+
_set_layout_attributes(
186+
op, [layout] * len(op.operands), [layout] * len(op.results))
187+
188+
for op in module.body:
189+
traverse_op(op, set_default_layout)

0 commit comments

Comments (0)