@@ -48,6 +48,8 @@ limitations under the License.
4848#include " absl/status/statusor.h"
4949#include " absl/strings/str_cat.h"
5050#include " absl/strings/string_view.h"
51+ #include " mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
52+ #include " mlir/include/mlir/IR/BuiltinTypes.h"
5153#include " mlir/include/mlir/IR/Diagnostics.h"
5254#include " tsl/platform/statusor.h"
5355
@@ -311,6 +313,184 @@ llvm::LogicalResult AsyncStoreOp::verify() {
311313 getSliceLengths (), getIndices ().size ());
312314}
313315
316+ namespace {
317+ llvm::FailureOr<WGMMALayout> GetWgmmaLayout (mlir::Location loc,
318+ mlir::MemRefType type,
319+ absl::string_view name,
320+ SwizzlingMode swizzling_mode) {
321+ auto error = [loc](auto ... params) {
322+ return emitError (loc, llvm::formatv (params...));
323+ };
324+
325+ auto [strides, offset] = mlir::getStridesAndOffset (type);
326+
327+ WGMMALayout layout = WGMMALayout::RowMajor;
328+ if (strides[3 ] == 1 ) {
329+ layout = WGMMALayout::RowMajor;
330+ } else if (strides[2 ] == 1 ) {
331+ layout = WGMMALayout::ColumnMajor;
332+ } else {
333+ return error (
334+ " At least one of the last two dimensions of `{0}` must have a "
335+ " stride of 1, but they do not: stride(dim 2)={1}, stride(dim 3)={2}" ,
336+ name, strides[2 ], strides[3 ]);
337+ }
338+
339+ auto shape = type.getShape ();
340+ if (layout == WGMMALayout::RowMajor && strides[2 ] != shape[3 ]) {
341+ return error (
342+ " When `{0}` has row-major layout, the stride of dimension 2 (={1}) "
343+ " must be equal to size of dimension 3 (={2})" ,
344+ shape[3 ], strides[2 ], shape[3 ]);
345+ }
346+
347+ if (layout == WGMMALayout::ColumnMajor && strides[3 ] != shape[2 ]) {
348+ return error (
349+ " When `{0}` has column-major layout, the stride of dimension 3 (={1}) "
350+ " must be equal to size of dimension 2 (={2})" ,
351+ shape[2 ], strides[3 ], shape[2 ]);
352+ }
353+
354+ if (strides[1 ] != shape[2 ] * shape[3 ]) {
355+ return error (
356+ " Dimension 1 ` of `{0}` must have a stride equal to size of dimension "
357+ " 2 times size of dimension 3 (={1}), but has {2}." ,
358+ name, shape[2 ] * shape[3 ], strides[1 ]);
359+ }
360+
361+ return layout;
362+ }
363+
364+ // This is the size of the M dimension in all wgmma instructions. It is fixed,
365+ // unlike the K and N dimensions.
366+ constexpr int kWgmmaSizeM = 64 ;
367+ } // namespace
368+
369+ llvm::LogicalResult WGMMAOp::verify () {
370+ auto error = [this ](auto ... params) {
371+ return emitOpError (llvm::formatv (params...));
372+ };
373+
374+ auto a_shaped_type = mlir::cast<mlir::ShapedType>(getA ().getType ());
375+ mlir::Type element_type = a_shaped_type.getElementType ();
376+ if (element_type != getB ().getType ().getElementType ()) {
377+ return error (" The `a` and `b` inputs must have the same element type." );
378+ }
379+
380+ auto b_shape = getB ().getType ().getShape ();
381+ if (b_shape.size () != 4 ) {
382+ return error (" The `b` input must have rank 4." );
383+ }
384+
385+ int element_bytewidth = element_type.getIntOrFloatBitWidth () / 8 ;
386+ int kn_tile = static_cast <int >(getSwizzle ()) / element_bytewidth;
387+
388+ int64_t groups_k = b_shape[0 ];
389+ int64_t groups_n = b_shape[1 ];
390+ int64_t k_group_size = b_shape[2 ];
391+ int64_t n_group_size = b_shape[3 ];
392+
393+ // It might be possible to relax that requirement, in particular to allow
394+ // n_group_size to be smaller than kn_tile and use padding.
395+ if (n_group_size != kn_tile) {
396+ return error (
397+ " The n group size ({0}) must be equal to swizzle/element_bytewidth "
398+ " ({1})." ,
399+ n_group_size, kn_tile);
400+ }
401+ if (k_group_size != kn_tile) {
402+ return error (
403+ " The k group size ({0}) must be equal to swizzle/element_bytewidth "
404+ " ({1})." ,
405+ k_group_size, kn_tile);
406+ }
407+
408+ auto b_layout = GetWgmmaLayout (getLoc (), getB ().getType (), " b" , getSwizzle ());
409+ if (failed (b_layout)) {
410+ return b_layout;
411+ }
412+
413+ int groups_m = 0 ;
414+ auto a_shape = a_shaped_type.getShape ();
415+ if (auto a_memref = dyn_cast<mlir::MemRefType>(getA ().getType ())) {
416+ if (a_shape.size () != 4 ) {
417+ return error (" When `a` is a memref, it must have rank 4." );
418+ }
419+
420+ groups_m = a_shape[0 ];
421+
422+ if (a_shape[1 ] != groups_k) {
423+ return error (
424+ " When `a` is a memref, dimension 1 ({0}) must be equal to groups_k "
425+ " which is `b`'s dimension 0 ({1})." ,
426+ a_shape[1 ], groups_k);
427+ }
428+
429+ if (a_shape[2 ] != kWgmmaSizeM ) {
430+ return error (
431+ " When `a` is a memref, dimension 2 ({0}) must be equal to {1}." ,
432+ a_shape[2 ], kWgmmaSizeM );
433+ }
434+
435+ if (a_shape[3 ] != kn_tile) {
436+ return error (
437+ " When `a` is a memref, dimension 3 ({0}) must be equal to kn_tile." ,
438+ a_shape[3 ]);
439+ }
440+
441+ auto a_layout = GetWgmmaLayout (getLoc (), a_memref, " a" , getSwizzle ());
442+ if (failed (a_layout)) {
443+ return a_layout;
444+ }
445+ if (*a_layout == WGMMALayout::ColumnMajor &&
446+ getSwizzle () != SwizzlingMode::k128ByteSwizzle) {
447+ // Not sure what the layout is like, since the tiles aren't square.
448+ return error (
449+ " When `a` is a memref and has a column-major layout, only a swizzle "
450+ " of 128 bytes is currently supported, but got {0}." );
451+ }
452+ } else {
453+ // a is a tensor in registers.
454+ if (!element_type.isBF16 () && !element_type.isF16 ()) {
455+ return error (
456+ " When `a` is a tensor in registers, it must have element type bf16 "
457+ " or f16." );
458+ }
459+ if (a_shape.size () != 2 ) {
460+ return error (" When `a` is a tensor in registers, it must have rank 2." );
461+ }
462+ if (a_shape[0 ] % kWgmmaSizeM ) {
463+ return error (
464+ " When `a` is a tensor in registers, dimension 0 must be a multiple "
465+ " of {0}, but got {1}." ,
466+ kWgmmaSizeM , a_shape[0 ]);
467+ }
468+
469+ groups_m = a_shape[0 ] / kWgmmaSizeM ;
470+
471+ if (a_shape[1 ] != kn_tile * groups_k) {
472+ return error (
473+ " When `a` is a tensor in registers, dimension 1 must be equal to "
474+ " kn_tile * groups_k ({0}*{1}), but got {2}." ,
475+ kn_tile, groups_k, a_shape[1 ]);
476+ }
477+ }
478+
479+ auto accShape = getAccumulator ().getType ().getShape ();
480+ if (accShape.size () != 2 ) {
481+ return error (" The accumulator must have rank 2." );
482+ }
483+ int expected_acc_0 = groups_m * kWgmmaSizeM ;
484+ int expected_acc_1 = groups_n * n_group_size;
485+ if (accShape[0 ] != expected_acc_0 || accShape[1 ] != expected_acc_1) {
486+ return error (
487+ " Incorrect accumulator shape. Expected: [{0},{1}], but got [{2},{3}]." ,
488+ expected_acc_0, expected_acc_1, accShape[0 ], accShape[1 ]);
489+ }
490+
491+ return llvm::success ();
492+ }
493+
314494void MosaicGPUDialect::initialize () {
315495 addTypes<
316496#define GET_TYPEDEF_LIST
0 commit comments