Skip to content

Commit 66f45d0

Browse files
[Mosaic GPU] Add WGMMA to the Mosaic GPU MLIR Dialect.
The op API is still in flux so I'm leaving some of the verification code untested. PiperOrigin-RevId: 705020066
1 parent cfdac00 commit 66f45d0

File tree

3 files changed

+353
-5
lines changed

3 files changed

+353
-5
lines changed

jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ limitations under the License.
4848
#include "absl/status/statusor.h"
4949
#include "absl/strings/str_cat.h"
5050
#include "absl/strings/string_view.h"
51+
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
52+
#include "mlir/include/mlir/IR/BuiltinTypes.h"
5153
#include "mlir/include/mlir/IR/Diagnostics.h"
5254
#include "tsl/platform/statusor.h"
5355

@@ -311,6 +313,184 @@ llvm::LogicalResult AsyncStoreOp::verify() {
311313
getSliceLengths(), getIndices().size());
312314
}
313315

316+
namespace {
317+
llvm::FailureOr<WGMMALayout> GetWgmmaLayout(mlir::Location loc,
318+
mlir::MemRefType type,
319+
absl::string_view name,
320+
SwizzlingMode swizzling_mode) {
321+
auto error = [loc](auto... params) {
322+
return emitError(loc, llvm::formatv(params...));
323+
};
324+
325+
auto [strides, offset] = mlir::getStridesAndOffset(type);
326+
327+
WGMMALayout layout = WGMMALayout::RowMajor;
328+
if (strides[3] == 1) {
329+
layout = WGMMALayout::RowMajor;
330+
} else if (strides[2] == 1) {
331+
layout = WGMMALayout::ColumnMajor;
332+
} else {
333+
return error(
334+
"At least one of the last two dimensions of `{0}` must have a "
335+
"stride of 1, but they do not: stride(dim 2)={1}, stride(dim 3)={2}",
336+
name, strides[2], strides[3]);
337+
}
338+
339+
auto shape = type.getShape();
340+
if (layout == WGMMALayout::RowMajor && strides[2] != shape[3]) {
341+
return error(
342+
"When `{0}` has row-major layout, the stride of dimension 2 (={1}) "
343+
"must be equal to size of dimension 3 (={2})",
344+
shape[3], strides[2], shape[3]);
345+
}
346+
347+
if (layout == WGMMALayout::ColumnMajor && strides[3] != shape[2]) {
348+
return error(
349+
"When `{0}` has column-major layout, the stride of dimension 3 (={1}) "
350+
"must be equal to size of dimension 2 (={2})",
351+
shape[2], strides[3], shape[2]);
352+
}
353+
354+
if (strides[1] != shape[2] * shape[3]) {
355+
return error(
356+
"Dimension 1 ` of `{0}` must have a stride equal to size of dimension "
357+
"2 times size of dimension 3 (={1}), but has {2}.",
358+
name, shape[2] * shape[3], strides[1]);
359+
}
360+
361+
return layout;
362+
}
363+
364+
// This is the size of the M dimension in all wgmma instructions. It is fixed,
365+
// unlike the K and N dimensions.
366+
constexpr int kWgmmaSizeM = 64;
367+
} // namespace
368+
369+
// Verifies the shape/type/layout contract between `a`, `b`, the swizzle, and
// the accumulator, as documented on the op:
//   - a (shared memory): [groups_m, groups_k, 64, kn_tile]
//   - a (registers):     [groups_m * 64, groups_k * kn_tile]
//   - b:                 [groups_k, groups_n, kn_tile, kn_tile]
//   - accumulator:       [groups_m * 64, groups_n * kn_tile]
// where kn_tile == swizzle_bytes / element_bytewidth.
llvm::LogicalResult WGMMAOp::verify() {
  auto error = [this](auto... params) {
    return emitOpError(llvm::formatv(params...));
  };

  // `a` may be a memref (shared memory) or a tensor (registers); both are
  // ShapedTypes, so the element type and shape can be read uniformly.
  auto a_shaped_type = mlir::cast<mlir::ShapedType>(getA().getType());
  mlir::Type element_type = a_shaped_type.getElementType();
  if (element_type != getB().getType().getElementType()) {
    return error("The `a` and `b` inputs must have the same element type.");
  }

  auto b_shape = getB().getType().getShape();
  if (b_shape.size() != 4) {
    return error("The `b` input must have rank 4.");
  }

  // The swizzle enum value encodes the swizzle width in bytes, so the K/N
  // tile extent (in elements) is swizzle / element_bytewidth.
  int element_bytewidth = element_type.getIntOrFloatBitWidth() / 8;
  int kn_tile = static_cast<int>(getSwizzle()) / element_bytewidth;

  int64_t groups_k = b_shape[0];
  int64_t groups_n = b_shape[1];
  int64_t k_group_size = b_shape[2];
  int64_t n_group_size = b_shape[3];

  // It might be possible to relax that requirement, in particular to allow
  // n_group_size to be smaller than kn_tile and use padding.
  if (n_group_size != kn_tile) {
    return error(
        "The n group size ({0}) must be equal to swizzle/element_bytewidth "
        "({1}).",
        n_group_size, kn_tile);
  }
  if (k_group_size != kn_tile) {
    return error(
        "The k group size ({0}) must be equal to swizzle/element_bytewidth "
        "({1}).",
        k_group_size, kn_tile);
  }

  auto b_layout = GetWgmmaLayout(getLoc(), getB().getType(), "b", getSwizzle());
  if (failed(b_layout)) {
    return b_layout;
  }

  int groups_m = 0;
  auto a_shape = a_shaped_type.getShape();
  if (auto a_memref = dyn_cast<mlir::MemRefType>(getA().getType())) {
    // `a` lives in shared memory: [groups_m, groups_k, 64, kn_tile].
    if (a_shape.size() != 4) {
      return error("When `a` is a memref, it must have rank 4.");
    }

    groups_m = a_shape[0];

    if (a_shape[1] != groups_k) {
      return error(
          "When `a` is a memref, dimension 1 ({0}) must be equal to groups_k "
          "which is `b`'s dimension 0 ({1}).",
          a_shape[1], groups_k);
    }

    if (a_shape[2] != kWgmmaSizeM) {
      return error(
          "When `a` is a memref, dimension 2 ({0}) must be equal to {1}.",
          a_shape[2], kWgmmaSizeM);
    }

    if (a_shape[3] != kn_tile) {
      return error(
          "When `a` is a memref, dimension 3 ({0}) must be equal to kn_tile.",
          a_shape[3]);
    }

    auto a_layout = GetWgmmaLayout(getLoc(), a_memref, "a", getSwizzle());
    if (failed(a_layout)) {
      return a_layout;
    }
    if (*a_layout == WGMMALayout::ColumnMajor &&
        getSwizzle() != SwizzlingMode::k128ByteSwizzle) {
      // Not sure what the layout is like, since the tiles aren't square.
      // Bug fix: the original message referenced {0} but passed no argument
      // to formatv, which yields a malformed diagnostic; report the actual
      // swizzle width.
      return error(
          "When `a` is a memref and has a column-major layout, only a swizzle "
          "of 128 bytes is currently supported, but got {0}.",
          static_cast<int>(getSwizzle()));
    }
  } else {
    // `a` is a tensor in registers: [groups_m * 64, groups_k * kn_tile].
    if (!element_type.isBF16() && !element_type.isF16()) {
      return error(
          "When `a` is a tensor in registers, it must have element type bf16 "
          "or f16.");
    }
    if (a_shape.size() != 2) {
      return error("When `a` is a tensor in registers, it must have rank 2.");
    }
    if (a_shape[0] % kWgmmaSizeM) {
      return error(
          "When `a` is a tensor in registers, dimension 0 must be a multiple "
          "of {0}, but got {1}.",
          kWgmmaSizeM, a_shape[0]);
    }

    groups_m = a_shape[0] / kWgmmaSizeM;

    if (a_shape[1] != kn_tile * groups_k) {
      return error(
          "When `a` is a tensor in registers, dimension 1 must be equal to "
          "kn_tile * groups_k ({0}*{1}), but got {2}.",
          kn_tile, groups_k, a_shape[1]);
    }
  }

  // Renamed accShape -> acc_shape for consistency with the snake_case used
  // throughout this function.
  auto acc_shape = getAccumulator().getType().getShape();
  if (acc_shape.size() != 2) {
    return error("The accumulator must have rank 2.");
  }
  int expected_acc_0 = groups_m * kWgmmaSizeM;
  int expected_acc_1 = groups_n * n_group_size;
  if (acc_shape[0] != expected_acc_0 || acc_shape[1] != expected_acc_1) {
    return error(
        "Incorrect accumulator shape. Expected: [{0},{1}], but got [{2},{3}].",
        expected_acc_0, expected_acc_1, acc_shape[0], acc_shape[1]);
  }

  return llvm::success();
}
493+
314494
void MosaicGPUDialect::initialize() {
315495
addTypes<
316496
#define GET_TYPEDEF_LIST

jaxlib/mosaic/dialect/gpu/mosaic_gpu.td

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,18 @@ def MosaicGPU_DimensionAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_Dimension, "
120120
def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode",
    "What swizzling to use for a memory access.",
    [
      // The enum value encodes the swizzle width in bytes. `kNoSwizzle` is 16
      // because, for tile-size computations, an unswizzled access behaves
      // like a 16-byte swizzle.
      I32EnumAttrCase<"kNoSwizzle", 16, "swizzle_none">,
      I32EnumAttrCase<"k32ByteSwizzle", 32, "swizzle_32">,
      I32EnumAttrCase<"k64ByteSwizzle", 64, "swizzle_64">,
      I32EnumAttrCase<"k128ByteSwizzle", 128, "swizzle_128">
    ]>{
  let cppNamespace = "::mosaic_gpu";
  let genSpecializedAttr = 0;
}

def MosaicGPU_SwizzlingModeAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_SwizzlingMode, "swizzle"> {
  // Printed/parsed as e.g. `#mosaic_gpu.swizzle<swizzle_64>`.
  let assemblyFormat = "`<` $value `>`";
}
133135

134136
def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
135137
let parameters = (ins Variadic<I64>:$tiling);
@@ -276,4 +278,78 @@ def MosaicGPU_AsyncStoreOp : Op<MosaicGPU_Dialect, "async_store",
276278
let hasVerifier = 1;
277279
}
278280

281+
def MosaicGPU_WGMMASupportedType : AnyTypeOf<[F16, BF16, F32],
    "A type supported by the WGMMA operation">;

def MosaicGPU_WGMMALayout :
  I32EnumAttr<"WGMMALayout", "The layout of the tiles of a WGMMA operation", [
    I32EnumAttrCase<"RowMajor", 0>,
    I32EnumAttrCase<"ColumnMajor", 1>
  ]> {
  let cppNamespace = "::mosaic_gpu";
  let genSpecializedAttr = 0;
}

def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", []> {
  // Typo fixes in the user-facing docs below: "asynchronously" (was
  // "asyncronously"), "multiply" (was "multiple"), "element_bytewidth" (was
  // "element_bytediwth"), "transformed" (was "transfromed"), "from registers"
  // (was "form registers").
  let summary = "Multiply two matrices asynchronously using warpgroup level matrix multiply operations.";
  let description = [{
    Schedules WGMMA operations that perform the following matrix multiply and
    accumulate:

      accumulator = a * b + accumulator

    This operation supports larger inputs than the PTX-level WGMMA operation
    and will schedule as many PTX-level WGMMA operations as needed to
    accomplish the calculation. The `b` matrix, and optionally `a`, needs to be
    provided in a 4-dimensional form, where the two minor-most dimensions
    express the tile (group) size and the two major-most dimensions represent
    the total number of tiles in each direction.

    The inputs should have the following shapes:
    - If `a` is in shared memory:
      - a: [groups_m, groups_k, 64, k]
    - If `a` is in registers:
      - a: [groups_m * 64, groups_k * k]
    - b: [groups_k, groups_n, k, k]
    - accumulator: [groups_m * 64, groups_n * k]
    Where:
    - `k == swizzle/element_bytewidth` (for `kNoSwizzle`, `swizzle` is 16.)

    The `accumulator` is always in registers and `b` is always in shared memory.
    The last two dimensions of any input in shared memory may be physically
    transposed in memory. This is inferred from the strides of the provided
    memrefs. `a` and `b` must have the same element type and when `a` is in
    registers only F16 or BF16 are supported.

    The `accumulator` must be a tensor with a FragmentedLayout. The WGMMA
    operation will be executed in the async proxy and any inputs in
    registers need to be synchronized with a memory fence.

    Usually `a` is read from shared memory if it is used directly in the WGMMA
    operation. If `a` needs to be transformed before it is used in the WGMMA
    operation, it may be more convenient to read it directly from registers.
    This avoids the need to store the data and wait for a fence.
  }];

  let arguments = (ins
    TensorOf<[MosaicGPU_WGMMASupportedType]>:$accumulator,
    AnyTypeOf<[
        MemRefOf<[MosaicGPU_WGMMASupportedType]>,
        TensorOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
    MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b,

    // Attributes
    DefaultValuedAttr<MosaicGPU_SwizzlingModeAttr, "SwizzlingMode::k128ByteSwizzle">:$swizzle
  );

  let assemblyFormat = [{
    `accumulator` `(` $accumulator `:` type($accumulator) `)`
    `a` `(` $a `:` type($a) `)`
    `b` `(` $b `:` type($b) `)`
    attr-dict
  }];

  let hasVerifier = 1;
}
354+
279355
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_

tests/mosaic/gpu_dialect_test.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,98 @@ def test_async_store_op_slice_lengths_size_must_match_source_rank(self):
475475
):
476476
self.module.operation.verify()
477477

478+
def test_wgmma_types_match(self):
  """Mismatched `a` (f16) and `b` (bf16) element types must fail to verify."""
  swizzle = ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>")
  with ir.InsertionPoint(self.module.body):
    def build_wgmma(accumulator, a, b):
      return mgpu.wgmma(accumulator, a, b, swizzle=swizzle)

    func.FuncOp.from_py_func(
        ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
        ir.MemRefType.get([2, 4, 64, 32], ir.F16Type.get()),
        ir.MemRefType.get([4, 5, 32, 32], ir.BF16Type.get()),
        name="wgmma",
    )(build_wgmma)

  with self.assertRaisesRegex(
      ir.MLIRError,
      "The `a` and `b` inputs must have the same element type.",
  ):
    self.module.operation.verify()
499+
500+
def test_wgmma_b_rank_is_4(self):
  """A rank-3 `b` operand must be rejected by the verifier."""
  swizzle = ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>")
  with ir.InsertionPoint(self.module.body):
    def build_wgmma(accumulator, a, b):
      return mgpu.wgmma(accumulator, a, b, swizzle=swizzle)

    func.FuncOp.from_py_func(
        ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
        ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
        ir.MemRefType.get([5, 32, 32], ir.BF16Type.get()),  # rank 3, not 4
        name="wgmma",
    )(build_wgmma)

  with self.assertRaisesRegex(
      ir.MLIRError,
      "The `b` input must have rank 4.",
  ):
    self.module.operation.verify()
521+
522+
def test_wgmma_b_shape_dim_3(self):
  """`b`'s n group size (dim 3) must equal swizzle/element_bytewidth (32)."""
  swizzle = ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>")
  with ir.InsertionPoint(self.module.body):
    def build_wgmma(accumulator, a, b):
      return mgpu.wgmma(accumulator, a, b, swizzle=swizzle)

    func.FuncOp.from_py_func(
        ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
        ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
        ir.MemRefType.get([4, 5, 32, 16], ir.BF16Type.get()),  # dim 3 == 16
        name="wgmma",
    )(build_wgmma)

  with self.assertRaisesRegex(
      ir.MLIRError,
      r"The n group size \(16\) must be equal to swizzle/element_bytewidth "
      r"\(32\)",
  ):
    self.module.operation.verify()
544+
545+
def test_wgmma_b_shape_dim_2(self):
  """`b`'s k group size (dim 2) must equal swizzle/element_bytewidth (32)."""
  swizzle = ir.Attribute.parse("#mosaic_gpu.swizzle<swizzle_64>")
  with ir.InsertionPoint(self.module.body):
    def build_wgmma(accumulator, a, b):
      return mgpu.wgmma(accumulator, a, b, swizzle=swizzle)

    func.FuncOp.from_py_func(
        ir.RankedTensorType.get([128, 160], ir.BF16Type.get()),
        ir.MemRefType.get([2, 4, 64, 32], ir.BF16Type.get()),
        ir.MemRefType.get([4, 5, 64, 32], ir.BF16Type.get()),  # dim 2 == 64
        name="wgmma",
    )(build_wgmma)

  with self.assertRaisesRegex(
      ir.MLIRError,
      r"The k group size \(64\) must be equal to swizzle/element_bytewidth "
      r"\(32\)",
  ):
    self.module.operation.verify()
567+
568+
# TODO(b/381371456): Add tests for the other WGMMA inputs.
569+
478570

479571
class DialectLoweringTest(DialectTest):
480572

0 commit comments

Comments
 (0)