Skip to content

Commit 3ba7ea8

Browse files
PR Review Round 1
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent c43bc26 commit 3ba7ea8

File tree

3 files changed

+34
-42
lines changed

3 files changed

+34
-42
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ def AMDGPU_GatherToLDSOp :
804804
TypeAttr:$transferType
805805
)>,
806806
Results<(outs)> {
807-
let summary = "MLIR wrapper for CDNA mfma instructions";
807+
let summary = "MLIR wrapper for CDNA Gather to LDS instructions";
808808
let description = [{
809809
The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
810810

@@ -845,7 +845,7 @@ def AMDGPU_ScaledMFMAOp :
845845
I32Attr:$opselA,
846846
I32Attr:$opselB)>,
847847
Results<(outs MFMAOutTypes: $destD)> {
848-
let summary = "MLIR wrapper for CDNA mfma instructions";
848+
let summary = "MLIR wrapper for CDNA scaled mfma instructions";
849849
let description = [{
850850
The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
851851
for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
@@ -857,7 +857,7 @@ def AMDGPU_ScaledMFMAOp :
857857

858858
Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
859859
intrinsics that take an integer type of width `4K`. For example,
860-
one can provide a vector<4xi8> as an argument to an MFMA instruction that
860+
one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
861861
logically takes 4 i8s but whose intrinsics are specified to take an i32.
862862
In these cases, the bytes in the vector will be concatenated in little-endian
863863
order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
@@ -868,7 +868,7 @@ def AMDGPU_ScaledMFMAOp :
868868
size.
869869
- `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
870870
are omitted from this wrapper.
871-
- The negateA, negateB, and negateC flags in `amdgpu.mfma` are only supported for
871+
- The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
872872
double-precision operations on gfx94x and so are not included here.
873873
}];
874874
let assemblyFormat = [{

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include "llvm/ADT/STLExtras.h"
2525
#include "llvm/ADT/TypeSwitch.h"
26+
#include "llvm/Support/Casting.h"
2627
#include <optional>
2728

2829
namespace mlir {
@@ -826,19 +827,20 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
826827
}
827828

828829
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
829-
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
830-
return mfmaOpToScaledIntrinsic(
831-
mfma.getSourceA().getType(), mfma.getSourceB().getType(),
832-
mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
833-
mfma.getBlocks(), chipset);
834-
}
835-
836-
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
837-
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
838-
return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
839-
smfma.getSourceB().getType(),
840-
smfma.getDestC().getType(), smfma.getM(),
841-
smfma.getN(), smfma.getK(), 1u, chipset);
830+
mfmaOpToScaledIntrinsic(Operation *op, Chipset chipset) {
831+
if (auto mfma = llvm::dyn_cast_or_null<MFMAOp>(op)) {
832+
return mfmaOpToScaledIntrinsic(
833+
mfma.getSourceA().getType(), mfma.getSourceB().getType(),
834+
mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
835+
mfma.getBlocks(), chipset);
836+
}
837+
if (auto smfma = llvm::dyn_cast_or_null<ScaledMFMAOp>(op)) {
838+
return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
839+
smfma.getSourceB().getType(),
840+
smfma.getDestC().getType(), smfma.getM(),
841+
smfma.getN(), smfma.getK(), 1u, chipset);
842+
}
843+
return std::nullopt;
842844
}
843845

844846
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -964,7 +966,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
964966

965967
struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
966968
ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
967-
: ConvertOpToLLVMPattern<ScaledMFMAOp>(converter), chipset(chipset) {}
969+
: ConvertOpToLLVMPattern(converter), chipset(chipset) {}
968970

969971
Chipset chipset;
970972

@@ -986,7 +988,7 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
986988
return op.emitOpError(
987989
"no intrinsic matching Scaled MFMA size on given chipset");
988990

989-
StringRef intrinsicName = std::get<0>(*maybeScaledIntrinsic);
991+
auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
990992
OperationState loweredOp(loc, intrinsicName);
991993
loweredOp.addTypes(intrinsicOutType);
992994
loweredOp.addOperands(
@@ -997,7 +999,6 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
997999
Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
9981000
Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
9991001
Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
1000-
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
10011002
loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
10021003
createI32Constant(rewriter, loc, bTypeCode),
10031004
/*scale A byte=*/opselA, /*scale A=*/scaleA,

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -507,44 +507,35 @@ LogicalResult GatherToLDSOp::verify() {
507507
}
508508

509509
LogicalResult ScaledMFMAOp::verify() {
510-
unsigned opselA = getOpselA();
511-
unsigned opselB = getOpselB();
512-
513-
opselA >>= 8;
514-
opselB >>= 8;
510+
unsigned opselA = getOpselA() >> 8;
511+
unsigned opselB = getOpselB() >> 8;
515512

516513
if (opselA != 0)
517-
return emitOpError("Opsel A must be a zero extended 8 bit value.");
514+
return emitOpError("Opsel A must be a zero extended 8 bit value");
518515

519516
if (opselB != 0)
520-
return emitOpError("Opsel B must be a zero extended 8 bit value.");
521-
522-
auto validType = [&](Type mlirElemType) {
523-
return llvm::TypeSwitch<Type, bool>(mlirElemType)
524-
.Case([](Float8E4M3FNType) { return true; })
525-
.Case([](Float8E5M2Type) { return true; })
526-
.Case([](Float6E2M3FNType) { return true; })
527-
.Case([](Float6E3M2FNType) { return true; })
528-
.Case([](Float4E2M1FNType) { return true; })
529-
.Default([](Type) { return false; });
530-
};
517+
return emitOpError("Opsel B must be a zero extended 8 bit value");
518+
519+
auto isValidType =
520+
llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type, Float6E2M3FNType,
521+
Float6E3M2FNType, Float4E2M1FNType>;
531522

532523
Type aType = getSourceA().getType();
533524
Type bType = getSourceB().getType();
534525
aType = getElementTypeOrSelf(aType);
535526
bType = getElementTypeOrSelf(bType);
536-
if (!validType(aType))
537-
return emitOpError("Source A must be of element type fp4, fp6 or fp8.");
538-
if (!validType(bType))
539-
return emitOpError("Source B must be of element type fp4, fp6 or fp8.");
527+
if (!isValidType(aType))
528+
return emitOpError("Source A must be of element type fp4, fp6 or fp8");
529+
if (!isValidType(bType))
530+
return emitOpError("Source B must be of element type fp4, fp6 or fp8");
540531

541532
unsigned m = getM();
542533
unsigned n = getN();
543534
unsigned k = getK();
544535
bool tileConfig1 = (m == n && n == 32 && k == 64);
545536
bool tileConfig2 = (m == n && n == 16 && k == 128);
546537
if (!tileConfig1 && !tileConfig2)
547-
return emitOpError("Invalid tile size for scaled mfma.");
538+
return emitOpError("Invalid tile size for scaled mfma");
548539

549540
return success();
550541
}

0 commit comments

Comments
 (0)