Skip to content

Commit 3915f4a

Browse files
bchetioui (Google-ML-Automation)
authored and committed
[Mosaic GPU] Commit to using Vectors everywhere (and no Tensors).
PiperOrigin-RevId: 707912637
1 parent c4fae4a commit 3915f4a

File tree

5 files changed

+13
-20
lines changed

5 files changed

+13
-20
lines changed

jax/experimental/mosaic/gpu/layout_inference.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,16 +164,14 @@ def inference_step(op: ir.Operation):
164164
def to_default_layout(ty: ir.Type) -> ir.Attribute | None:
165165
if ir.VectorType.isinstance(ty):
166166
layout = WGStridedFragLayout.from_shaped_type(ty)
167-
elif ir.RankedTensorType.isinstance(ty):
168-
layout = WGStridedFragLayout.from_shaped_type(ty)
169167
else:
170168
return None
171169
return to_strided_fragmented_layout_attr(layout)
172170

173171
def set_default_layout(op: ir.OpView):
174172
if should_have_layout(op) and not has_any_layout_set(op):
175173
# TODO(bchetioui): consistently set layouts only for supported argument
176-
# types (i.e. skip non-vector/tensor typed arguments.)
174+
# types (i.e. skip non-vector typed arguments.)
177175
in_layouts = []
178176
for operand in op.operands:
179177
if (layout := to_default_layout(operand.type)) is not None:

jax/experimental/mosaic/gpu/layouts.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,7 @@ def to_splat_fragmented_layout_attr(layout: WGSplatFragLayout) -> ir.Attribute:
7272
def should_have_layout(op: ir.OpView) -> bool:
7373
"""Returns 'true' if the operation should be assigned a layout."""
7474

75-
def is_array(v: ir.Value):
76-
ty = v.type
77-
return ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty)
78-
75+
is_array = lambda v: ir.VectorType.isinstance(v.type)
7976
return any(map(is_array, itertools.chain(op.operands, op.results))) # type: ignore
8077

8178

jaxlib/mosaic/dialect/gpu/mosaic_gpu.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
327327
memrefs. `a` and `b` must have the same element type and when `a` is in
328328
registers only F16 or BF16 are supported.
329329

330-
The `accumulator` must be a tensor with a FragmentedLayout. The WGMMA
330+
The `accumulator` must be a vector with a FragmentedLayout. The WGMMA
331331
operation will be executed in the async proxy and any inputs in
332332
registers need to be synchronized with a memory fence.
333333

@@ -338,10 +338,10 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
338338
}];
339339

340340
let arguments = (ins
341-
TensorOf<[MosaicGPU_WGMMASupportedType]>:$accumulator,
341+
VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>:$accumulator,
342342
AnyTypeOf<[
343343
MemRefOf<[MosaicGPU_WGMMASupportedType]>,
344-
TensorOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
344+
VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
345345
MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b,
346346

347347
// Attributes

tests/mosaic/gpu_dialect_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def test_async_store_op_slice_lengths_size_must_match_source_rank(self):
487487
def test_wgmma_types_match(self):
488488
with ir.InsertionPoint(self.module.body):
489489
func.FuncOp.from_py_func(
490-
ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
490+
ir.VectorType.get([128, 160], ir.BF16Type.get()),
491491
ir.MemRefType.get([2, 4, 64, 32], ir.F16Type.get()),
492492
ir.MemRefType.get([4, 5, 32, 32], ir.BF16Type.get()),
493493
name="wgmma",
@@ -509,7 +509,7 @@ def test_wgmma_types_match(self):
509509
def test_wgmma_b_rank_is_4(self):
510510
with ir.InsertionPoint(self.module.body):
511511
func.FuncOp.from_py_func(
512-
ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
512+
ir.VectorType.get([128, 160], ir.BF16Type.get()),
513513
ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
514514
ir.MemRefType.get([5, 32, 32], ir.BF16Type.get()),
515515
name="wgmma",
@@ -531,7 +531,7 @@ def test_wgmma_b_rank_is_4(self):
531531
def test_wgmma_b_shape_dim_3(self):
532532
with ir.InsertionPoint(self.module.body):
533533
func.FuncOp.from_py_func(
534-
ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
534+
ir.VectorType.get([128, 160], ir.BF16Type.get()),
535535
ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
536536
ir.MemRefType.get([4, 5, 32, 16], ir.BF16Type.get()),
537537
name="wgmma",
@@ -554,7 +554,7 @@ def test_wgmma_b_shape_dim_3(self):
554554
def test_wgmma_b_shape_dim_2(self):
555555
with ir.InsertionPoint(self.module.body):
556556
func.FuncOp.from_py_func(
557-
ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
557+
ir.VectorType.get([128, 160], ir.BF16Type.get()),
558558
ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
559559
ir.MemRefType.get([4, 5, 64, 32], ir.BF16Type.get()),
560560
name="wgmma",

tests/mosaic/gpu_layout_inference_test.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,12 @@ def setUp(self):
4545
self.enter_context(ir.Location.unknown())
4646
self.module = ir.Module.create()
4747

48-
@parameterized.parameters(ir.RankedTensorType, ir.VectorType)
49-
def test_infer_layout_default(self, type_constructor):
48+
def test_infer_layout_default(self):
5049
shape = (16, 8)
5150
elt_type = ir.BF16Type.get()
5251

5352
with ir.InsertionPoint(self.module.body):
54-
ab_type = type_constructor.get(shape, elt_type)
53+
ab_type = ir.VectorType.get(shape, elt_type)
5554
const_zero = ir.FloatAttr.get(elt_type, 0)
5655
const_one = ir.FloatAttr.get(elt_type, 1)
5756
a = arith.ConstantOp(
@@ -80,13 +79,12 @@ def test_infer_layout_default(self, type_constructor):
8079
op.attributes["out_layouts"], [layout] * len(op.results)
8180
)
8281

83-
@parameterized.parameters(ir.RankedTensorType, ir.VectorType)
84-
def test_infer_layout_for_pointwise_op(self, type_constructor):
82+
def test_infer_layout_for_pointwise_op(self):
8583
shape = (4, 8)
8684
elt_type = ir.BF16Type.get()
8785

8886
with ir.InsertionPoint(self.module.body):
89-
ab_type = type_constructor.get(shape, elt_type)
87+
ab_type = ir.VectorType.get(shape, elt_type)
9088
const_zero = ir.FloatAttr.get(elt_type, 0)
9189
const_one = ir.FloatAttr.get(elt_type, 1)
9290
a = arith.ConstantOp(

0 commit comments

Comments (0)