Skip to content

Commit 846c389

Browse files
PR review round 2
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 3ba7ea8 commit 846c389

File tree

4 files changed

+62
-80
lines changed

4 files changed

+62
-80
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,11 @@ def MFMAOutTypes : AnyTypeOf<[F64,
687687
VectorOfLengthAndType<[4, 16, 32], [F32]>,
688688
VectorOfLengthAndType<[4, 16, 32], [I32]>,
689689
VectorOfLengthAndType<[4], [F64]>]>;
690+
// scaled_mfma
691+
def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>,
692+
VectorOfLengthAndType<[8, 32], [F8E5M2, F8E4M3FN]>,
693+
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
694+
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16, 32], [F32]>]>;
690695
// wmma
691696
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
692697
[4, 8, 16],
@@ -837,14 +842,14 @@ def AMDGPU_ScaledMFMAOp :
837842
I32Attr:$m,
838843
I32Attr:$n,
839844
I32Attr:$k,
840-
MFMAInTypes:$sourceA,
841-
MFMAInTypes:$sourceB,
842-
MFMAOutTypes:$destC,
843-
I32Attr:$scaleA,
844-
I32Attr:$scaleB,
845-
I32Attr:$opselA,
846-
I32Attr:$opselB)>,
847-
Results<(outs MFMAOutTypes: $destD)> {
845+
ScaledMFMAInTypes:$sourceA,
846+
ScaledMFMAInTypes:$sourceB,
847+
ScaledMFMAOutTypes:$destC,
848+
AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesA,
849+
AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesB,
850+
I32Attr:$scalesIdxA,
851+
I32Attr:$scalesIdxB)>,
852+
Results<(outs ScaledMFMAOutTypes: $destD)> {
848853
let summary = "MLIR wrapper for CDNA scaled mfma instructions";
849854
let description = [{
850855
The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
@@ -872,9 +877,9 @@ def AMDGPU_ScaledMFMAOp :
872877
double-precision operations on gfx94x and so are not included here.
873878
}];
874879
let assemblyFormat = [{
875-
$sourceA `*` $sourceB `+` $destC
880+
`(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
876881
attr-dict
877-
`:` type($sourceA) `,` type($sourceB) `,` type($destC)
882+
`:` type($sourceA) `,` type($scalesA) `,` type($sourceB) `,` type($scalesB) `,` type($destC)
878883
}];
879884
let hasVerifier = 1;
880885
}

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -974,19 +974,15 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
974974
matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
975975
ConversionPatternRewriter &rewriter) const override {
976976
Location loc = op.getLoc();
977-
Type outType = typeConverter->convertType(op.getDestD().getType());
978-
Type intrinsicOutType = outType;
979-
if (auto outVecType = dyn_cast<VectorType>(outType))
980-
if (outVecType.getElementType().isBF16())
981-
intrinsicOutType = outVecType.clone(rewriter.getI16Type());
977+
Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
982978

983-
if (chipset.majorVersion != 9 || chipset < kGfx908)
984-
return op->emitOpError("Scaled MFMA only supported on gfx908+");
979+
if (chipset.majorVersion != 9 || chipset < kGfx950)
980+
return op->emitOpError("scaled MFMA only supported on gfx908+");
985981
std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
986982
maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
987983
if (!maybeScaledIntrinsic.has_value())
988984
return op.emitOpError(
989-
"no intrinsic matching Scaled MFMA size on given chipset");
985+
"no intrinsic matching scaled MFMA size on given chipset");
990986

991987
auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
992988
OperationState loweredOp(loc, intrinsicName);
@@ -995,17 +991,18 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
995991
{convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
996992
convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
997993
adaptor.getDestC()});
998-
Value scaleA = createI32Constant(rewriter, loc, adaptor.getScaleA());
999-
Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
1000-
Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
1001-
Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
1002-
loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
1003-
createI32Constant(rewriter, loc, bTypeCode),
1004-
/*scale A byte=*/opselA, /*scale A=*/scaleA,
1005-
/*scale B byte=*/opselB, /*scale B=*/scaleB});
994+
Value scalesIdxA = createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
995+
Value scalesIdxB = createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
996+
loweredOp.addOperands(
997+
{createI32Constant(rewriter, loc, aTypeCode),
998+
createI32Constant(rewriter, loc, bTypeCode),
999+
/*scales A*/
1000+
convertMFMAVectorOperand(rewriter, loc, adaptor.getScalesA()),
1001+
/*scales B*/
1002+
convertMFMAVectorOperand(rewriter, loc, adaptor.getScalesB()),
1003+
/*scales idx A=*/scalesIdxA,
1004+
/*scales idx B=*/scalesIdxB});
10061005
Value lowered = rewriter.create(loweredOp)->getResult(0);
1007-
if (outType != intrinsicOutType)
1008-
lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
10091006
rewriter.replaceOp(op, lowered);
10101007
return success();
10111008
}

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -507,35 +507,14 @@ LogicalResult GatherToLDSOp::verify() {
507507
}
508508

509509
LogicalResult ScaledMFMAOp::verify() {
510-
unsigned opselA = getOpselA() >> 8;
511-
unsigned opselB = getOpselB() >> 8;
512-
513-
if (opselA != 0)
514-
return emitOpError("Opsel A must be a zero extended 8 bit value");
515-
516-
if (opselB != 0)
517-
return emitOpError("Opsel B must be a zero extended 8 bit value");
518-
519-
auto isValidType =
520-
llvm::IsaPred<Float8E4M3FNType, Float8E5M2Type, Float6E2M3FNType,
521-
Float6E3M2FNType, Float4E2M1FNType>;
522-
523-
Type aType = getSourceA().getType();
524-
Type bType = getSourceB().getType();
525-
aType = getElementTypeOrSelf(aType);
526-
bType = getElementTypeOrSelf(bType);
527-
if (!isValidType(aType))
528-
return emitOpError("Source A must be of element type fp4, fp6 or fp8");
529-
if (!isValidType(bType))
530-
return emitOpError("Source B must be of element type fp4, fp6 or fp8");
531-
532-
unsigned m = getM();
533-
unsigned n = getN();
534-
unsigned k = getK();
535-
bool tileConfig1 = (m == n && n == 32 && k == 64);
536-
bool tileConfig2 = (m == n && n == 16 && k == 128);
537-
if (!tileConfig1 && !tileConfig2)
538-
return emitOpError("Invalid tile size for scaled mfma");
510+
unsigned scalesIdxA = getScalesIdxA();
511+
unsigned scalesIdxB = getScalesIdxB();
512+
513+
if (scalesIdxA > 3)
514+
return emitOpError("scales idx A must be a value from 0 to 3 inclusive");
515+
516+
if (scalesIdxB > 3)
517+
return emitOpError("scales idx B must be a value from 0 to 3 inclusive");
539518

540519
return success();
541520
}

mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -55,46 +55,47 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5555
func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
5656
%arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
5757
%arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
58-
%arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>) {
58+
%arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
59+
%arg7 : vector<4xi8>, %arg8 : i8) {
5960

60-
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
61-
// CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
6261
// CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
62+
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
63+
// CHECK: llvm.bitcast
6364

64-
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
65-
amdgpu.scaled_mfma %arg2 * %arg2 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
66-
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
67-
amdgpu.scaled_mfma %arg2 * %arg2 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
65+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
66+
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * ( %arg8 [ 1 ] * %arg2 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf8E4M3FN>, vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<16xf32>
67+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
68+
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * ( %arg8 [ 1 ] * %arg2 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf8E4M3FN>, vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<4xf32>
6869

6970
// CHECK: llvm.bitcast
7071

71-
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
72-
amdgpu.scaled_mfma %arg3 * %arg3 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
73-
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
74-
amdgpu.scaled_mfma %arg3 * %arg3 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
72+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
73+
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * ( %arg8 [ 1 ] * %arg3 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf8E5M2>, vector<4xi8>, vector<32xf8E5M2>, i8, vector<16xf32>
74+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
75+
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * ( %arg8 [ 1 ] * %arg3 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf8E5M2>, vector<4xi8>, vector<32xf8E5M2>, i8, vector<4xf32>
7576

7677
// CHECK: llvm.bitcast
7778

78-
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
79-
amdgpu.scaled_mfma %arg4 * %arg4 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
80-
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
81-
amdgpu.scaled_mfma %arg4 * %arg4 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
79+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
80+
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * ( %arg8 [ 1 ] * %arg4 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf6E2M3FN>, vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<16xf32>
81+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
82+
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * ( %arg8 [ 1 ] * %arg4 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf6E2M3FN>, vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<4xf32>
8283

8384
// CHECK: llvm.bitcast
8485
// CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
8586

86-
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
87-
amdgpu.scaled_mfma %arg5 * %arg5 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
88-
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
89-
amdgpu.scaled_mfma %arg5 * %arg5 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
87+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
88+
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * ( %arg8 [ 1 ] * %arg5 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf6E3M2FN>, vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<16xf32>
89+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
90+
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * ( %arg8 [ 1 ] * %arg5 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf6E3M2FN>, vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<4xf32>
9091

9192
// CHECK: llvm.bitcast
9293
// CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
9394

94-
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
95-
amdgpu.scaled_mfma %arg6 * %arg6 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
96-
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
97-
amdgpu.scaled_mfma %arg6 * %arg6 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
95+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
96+
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * ( %arg8 [ 1 ] * %arg6 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf4E2M1FN>, vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<16xf32>
97+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
98+
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * ( %arg8 [ 1 ] * %arg6 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf4E2M1FN>, vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<4xf32>
9899

99100
func.return
100101
}

0 commit comments

Comments
 (0)