
Commit 9a686e0

bchetioui authored and Google-ML-Automation committed
[Mosaic GPU] Add initial transform inference rules for vector.{load,store}.
PiperOrigin-RevId: 737703568
1 parent de9ad6b commit 9a686e0

File tree

2 files changed, +243 -0 lines changed


jax/experimental/mosaic/gpu/transform_inference.py

Lines changed: 59 additions & 0 deletions
@@ -25,8 +25,12 @@
 
 from jax._src.lib import mosaic_gpu_dialect as mgpu
 from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import arith
+from jax._src.lib.mlir.dialects import vector
 
+from . import fragmented_array as fa
 from . import inference_utils
+from . import layouts as layouts_lib
 from . import utils
 
 # mypy: ignore-errors
@@ -40,6 +44,7 @@ def _add_transform_inference_rule(
     op: type[ir.OpView], rule: TransformInferenceRule
 ):
   _transform_inference_rules[op.OPERATION_NAME] = rule  # pytype: disable=attribute-error
+  return rule
 
 
 def _set_transform_attributes(
@@ -110,6 +115,60 @@ def _infer_async_load_transforms(op: mgpu.AsyncLoadOp) -> OptionalTransforms:
   return None if in_transforms is None else ([in_transforms], [])
 
 
+@partial(_add_transform_inference_rule, vector.LoadOp)
+@partial(_add_transform_inference_rule, vector.StoreOp)
+def _infer_vector_load_store_transforms(
+    op: vector.LoadOp | vector.StoreOp,
+) -> OptionalTransforms:
+  for i in op.indices:
+    index_defining_op = i.owner.opview
+    if (
+        not isinstance(index_defining_op, arith.ConstantOp)
+        or index_defining_op.literal_value != 0
+    ):
+      # TODO(bchetioui): handle slicing.
+      raise NotImplementedError(
+          f"Only constants with value 0 are supported as indices for {op}"
+      )
+
+  if isinstance(op, vector.LoadOp):
+    [layout_attr] = inference_utils.out_layouts(op)
+  else:
+    assert isinstance(op, vector.StoreOp)
+    [layout_attr] = inference_utils.in_layouts(op)
+
+  layout = layouts_lib.from_layout_attr(layout_attr)
+  transforms = inference_utils.value_transforms(op.base)
+
+  if layout == fa.WGMMA_LAYOUT:
+    layout_transforms = infer_transforms_for_wgmma_ref(
+        ir.MemRefType(op.base.type)
+    )
+  elif (isinstance(layout, fa.WGStridedFragLayout) or
+        isinstance(layout, fa.WGSplatFragLayout)):
+    layout_transforms = None
+  else:
+    raise NotImplementedError(
+        f"Got layout {layout} which is not yet supported"
+    )
+
+  if transforms is not None and layout_transforms is not None:
+    if transforms != layout_transforms:
+      raise NotImplementedError(
+          f"Conflicting transforms for {op.base} in {op}: "
+          f"{transforms} != {layout_transforms}."
+      )
+    return [transforms], []
+
+  if transforms is not None:
+    return [transforms], []
+
+  if layout_transforms is not None:
+    return [layout_transforms], []
+
+  return None
+
+
 def _should_have_transforms(op: ir.OpView) -> bool:
   """Returns 'True' if the operation should be assigned in/out transforms."""
   return any(
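
Note: the second hunk above also makes _add_transform_inference_rule return its rule argument, which is what allows the new inference function to be registered for both vector.LoadOp and vector.StoreOp through stacked functools.partial decorators. Below is a minimal, self-contained sketch of that registration pattern; the names _register, _rules and infer are illustrative stand-ins and are not part of the real module.

    from functools import partial

    _rules = {}

    def _register(op_name, rule):
      _rules[op_name] = rule
      return rule  # returning the rule keeps the decorated name bound to the original function

    @partial(_register, "vector.load")
    @partial(_register, "vector.store")
    def infer(op):
      ...  # shared inference logic for both op kinds

    # Both keys map to the same, unchanged function object.
    assert _rules["vector.load"] is _rules["vector.store"] is infer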

tests/mosaic/gpu_transform_inference_test.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,11 @@
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import arith
 from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import vector
 import jax.experimental.mosaic.gpu as mgpu
+from jax.experimental.mosaic.gpu import fragmented_array as fa
 from jax.experimental.mosaic.gpu import inference_utils
+from jax.experimental.mosaic.gpu import layouts as layouts_lib
 import numpy as np
 
 
@@ -162,6 +165,187 @@ def body(gmem_ref, smem_ref):
     )
     self.assertEmpty(inference_utils.out_transforms(async_store_op))
 
+  def test_infer_transforms_for_vector_load_op_derives_from_destination(self):
+    vector_load_op = None
+    shape = (64, 64)
+    elt_ty = ir.BF16Type.get()
+
+    def body(smem_ref):
+      nonlocal vector_load_op
+      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
+      vector_load_op = vector.LoadOp(
+          ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
+      )
+
+    with ir.InsertionPoint(self.module.body):
+      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
+      func.FuncOp.from_py_func(smem_ty)(body)
+
+    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
+        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
+    )
+
+    mgpu.infer_transforms(self.module)
+
+    expected_transforms = ir.ArrayAttr.get([
+        mgpu.dialect.TileTransformAttr.get((8, 64)),
+        mgpu.dialect.SwizzleTransformAttr.get(128),
+    ])
+
+    self.assertSequenceEqual(
+        inference_utils.in_transforms(vector_load_op), [expected_transforms]
+    )
+    self.assertEmpty(inference_utils.out_transforms(vector_load_op))
+
+  def test_infer_transforms_for_vector_load_op_derives_from_source(self):
+    vector_load_op = None
+    shape = (64, 64)
+    elt_ty = ir.BF16Type.get()
+
+    def body(smem_ref):
+      nonlocal vector_load_op
+      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
+      vector_load_op = vector.LoadOp(
+          ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
+      )
+
+    with ir.InsertionPoint(self.module.body):
+      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
+      f = func.FuncOp.from_py_func(smem_ty)(body).func_op
+
+    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
+        [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
+    )
+    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
+    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])
+
+    mgpu.infer_transforms(self.module)
+
+    self.assertSequenceEqual(
+        inference_utils.in_transforms(vector_load_op), [transforms]
+    )
+    self.assertEmpty(inference_utils.out_transforms(vector_load_op))
+
+  def test_infer_transforms_for_vector_load_op_raises_on_mismatches(self):
+    vector_load_op = None
+    shape = (64, 64)
+    elt_ty = ir.BF16Type.get()
+
+    def body(smem_ref):
+      nonlocal vector_load_op
+      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
+      vector_load_op = vector.LoadOp(
+          ir.VectorType.get(shape, elt_ty), smem_ref, [zero] * len(shape)
+      )
+
+    with ir.InsertionPoint(self.module.body):
+      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
+      f = func.FuncOp.from_py_func(smem_ty)(body).func_op
+
+    vector_load_op.attributes["out_layouts"] = ir.ArrayAttr.get(
+        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
+    )
+    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
+    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])
+
+    with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
+      mgpu.infer_transforms(self.module)
+
+  def test_infer_transforms_for_vector_store_op_derives_from_destination(self):
+    vector_store_op = None
+    shape = (64, 64)
+    elt_ty = ir.BF16Type.get()
+
+    def body(smem_ref, value_to_store):
+      nonlocal vector_store_op
+      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
+      vector_store_op = vector.StoreOp(
+          value_to_store, smem_ref, [zero] * len(shape)
+      )
+
+    with ir.InsertionPoint(self.module.body):
+      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
+      value_ty = ir.VectorType.get(shape, elt_ty)
+      func.FuncOp.from_py_func(smem_ty, value_ty)(body)
+
+    vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
+        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
+    )
+
+    mgpu.infer_transforms(self.module)
+
+    expected_transforms = ir.ArrayAttr.get([
+        mgpu.dialect.TileTransformAttr.get((8, 64)),
+        mgpu.dialect.SwizzleTransformAttr.get(128),
+    ])
+
+    self.assertSequenceEqual(
+        inference_utils.in_transforms(vector_store_op), [expected_transforms]
+    )
+    self.assertEmpty(inference_utils.out_transforms(vector_store_op))
+
+  def test_infer_transforms_for_vector_store_op_derives_from_source(self):
+    vector_store_op = None
+    shape = (64, 64)
+    elt_ty = ir.BF16Type.get()
+
+    def body(smem_ref, value_to_store):
+      nonlocal vector_store_op
+      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
+      vector_store_op = vector.StoreOp(
+          value_to_store, smem_ref, [zero] * len(shape)
+      )
+
+    with ir.InsertionPoint(self.module.body):
+      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
+      value_ty = ir.VectorType.get(shape, elt_ty)
+      f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op
+
+    vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
+        [layouts_lib.to_layout_attr(fa.WGStridedFragLayout(shape, vec_size=4))]
+    )
+    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
+    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])
+
+    mgpu.infer_transforms(self.module)
+
+    self.assertSequenceEqual(
+        inference_utils.in_transforms(vector_store_op), [transforms]
+    )
+    self.assertEmpty(inference_utils.out_transforms(vector_store_op))
+
+  def test_infer_transforms_for_vector_store_op_raises_on_mismatches(self):
+    vector_store_op = None
+    shape = (64, 64)
+    elt_ty = ir.BF16Type.get()
+
+    def body(smem_ref, value_to_store):
+      nonlocal vector_store_op
+      zero = arith.constant(ir.IntegerType.get_signless(32), 0)
+      vector_store_op = vector.StoreOp(
+          value_to_store, smem_ref, [zero] * len(shape)
+      )
+
+    with ir.InsertionPoint(self.module.body):
+      smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
+      smem_ty = ir.MemRefType.get(shape, elt_ty, memory_space=smem)
+      value_ty = ir.VectorType.get(shape, elt_ty)
+      f = func.FuncOp.from_py_func(smem_ty, value_ty)(body).func_op
+
+    vector_store_op.attributes["in_layouts"] = ir.ArrayAttr.get(
+        [layouts_lib.to_layout_attr(fa.WGMMA_LAYOUT)]
+    )
+    transforms = ir.ArrayAttr.get([mgpu.dialect.TileTransformAttr.get((8, 64))])
+    f.attributes["in_transforms"] = ir.ArrayAttr.get([transforms])
+
+    with self.assertRaisesRegex(NotImplementedError, "Conflicting transforms"):
+      mgpu.infer_transforms(self.module)
+
 
 if __name__ == "__main__":
   parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
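
Taken together, the tests above exercise the three outcomes of the new rule: transforms derived from the value's layout (a WGMMA-layout 64x64 bf16 ref resolves to a (8, 64) tile transform plus a 128-byte swizzle), transforms propagated from the memref operand when the layout alone implies none, and a NotImplementedError when the two sources disagree. The following plain-Python sketch restates that precedence with string stand-ins for the real transform attributes; resolve is illustrative only and not part of the Mosaic GPU API.

    def resolve(operand_transforms, layout_transforms):
      # Returns the transforms to attach, or None if nothing can be inferred.
      if operand_transforms is not None and layout_transforms is not None:
        if operand_transforms != layout_transforms:
          raise NotImplementedError("Conflicting transforms")
        return operand_transforms
      if operand_transforms is not None:
        return operand_transforms
      return layout_transforms  # may itself be None

    # The situations covered by the *_derives_from_* and *_raises_on_mismatches tests:
    assert resolve(None, "tile(8, 64), swizzle(128)") == "tile(8, 64), swizzle(128)"
    assert resolve("tile(8, 64)", None) == "tile(8, 64)"
    try:
      resolve("tile(8, 64)", "tile(8, 64), swizzle(128)")
    except NotImplementedError:
      pass  # conflicting sources raise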
