Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,11 @@ def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
// scaled_mfma
// Operand element types accepted by amdgpu.scaled_mfma: 8-element vectors of
// the FNUZ fp8 variants, 8- or 32-element vectors of OCP fp8, and 32-element
// vectors of fp6 / fp4.
def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>,
VectorOfLengthAndType<[8, 32], [F8E5M2, F8E4M3FN]>,
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
// Accumulator/result type: a vector of f32.
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16, 32], [F32]>]>;
// wmma
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
[4, 8, 16],
Expand Down Expand Up @@ -804,7 +809,7 @@ def AMDGPU_GatherToLDSOp :
TypeAttr:$transferType
)>,
Results<(outs)> {
let summary = "MLIR wrapper for CDNA mfma instructions";
let summary = "MLIR wrapper for CDNA Gather to LDS instructions";
let description = [{
The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.

Expand All @@ -830,4 +835,52 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}

def AMDGPU_ScaledMFMAOp :
    AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
                              Pure]>,
    Arguments<(ins
                   I32Attr:$m,
                   I32Attr:$n,
                   I32Attr:$k,
                   ScaledMFMAInTypes:$sourceA,
                   ScaledMFMAInTypes:$sourceB,
                   ScaledMFMAOutTypes:$destC,
                   AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesA,
                   AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesB,
                   I32Attr:$scalesIdxA,
                   I32Attr:$scalesIdxB)>,
    Results<(outs ScaledMFMAOutTypes: $destD)> {
  let summary = "MLIR wrapper for CDNA scaled mfma instructions";
  let description = [{
    The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
    for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
    multiple outer products in order to allow fast matrix multiplication.

    The wrapper will select an appropriate `mfma` instruction, if one is available,
    based on the provided `m`, `k`, and `n` attributes, along with the
    types of the source and destination arguments.

    Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
    intrinsics that take an integer type of width `4K`. For example,
    one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
    logically takes 4 i8s but whose intrinsics are specified to take an i32.
    In these cases, the bytes in the vector will be concatenated in little-endian
    order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).

    This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
    - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
      fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
      size.
    - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
      are omitted from this wrapper.
    - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
      double-precision operations on gfx94x and so are not included here.
  }];
  let assemblyFormat = [{
    `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
    attr-dict
    `:` type($sourceA) `,` type($scalesA) `,` type($sourceB) `,` type($scalesB) `,` type($destC)
  }];
  let hasVerifier = 1;
}
#endif // AMDGPU
71 changes: 63 additions & 8 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <optional>

namespace mlir {
Expand Down Expand Up @@ -826,11 +827,20 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
return mfmaOpToScaledIntrinsic(
mfma.getSourceA().getType(), mfma.getSourceB().getType(),
mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
mfma.getBlocks(), chipset);
mfmaOpToScaledIntrinsic(Operation *op, Chipset chipset) {
if (auto mfma = llvm::dyn_cast_or_null<MFMAOp>(op)) {
return mfmaOpToScaledIntrinsic(
mfma.getSourceA().getType(), mfma.getSourceB().getType(),
mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
mfma.getBlocks(), chipset);
}
if (auto smfma = llvm::dyn_cast_or_null<ScaledMFMAOp>(op)) {
return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
smfma.getSourceB().getType(),
smfma.getDestC().getType(), smfma.getM(),
smfma.getN(), smfma.getK(), 1u, chipset);
}
return std::nullopt;
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
Expand Down Expand Up @@ -954,6 +964,50 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}
};

/// Lowers `amdgpu.scaled_mfma` to the corresponding `rocdl.mfma.scale.*`
/// intrinsic on gfx950-class hardware.
struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

    // Scaled MFMA requires a gfx9 chip at least as new as gfx950. The
    // original message said "gfx908+", contradicting the kGfx950 check.
    if (chipset.majorVersion != 9 || chipset < kGfx950)
      return op->emitOpError("scaled MFMA only supported on gfx950+");
    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
    if (!maybeScaledIntrinsic.has_value())
      return op.emitOpError(
          "no intrinsic matching scaled MFMA size on given chipset");

    auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
    OperationState loweredOp(loc, intrinsicName);
    loweredOp.addTypes(intrinsicOutType);
    // Matrix operands A, B and accumulator C; small-float vectors are
    // converted (bitcast/packed) into the integer form the intrinsic takes.
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC()});
    Value scalesIdxA =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
    Value scalesIdxB =
        createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
    // Trailing operands: A/B format codes, the scale values, then the
    // scale opsel indices — this order matches the lowering test's CHECKs.
    loweredOp.addOperands(
        {createI32Constant(rewriter, loc, aTypeCode),
         createI32Constant(rewriter, loc, bTypeCode),
         /*scales A*/
         convertMFMAVectorOperand(rewriter, loc, adaptor.getScalesA()),
         /*scales B*/
         convertMFMAVectorOperand(rewriter, loc, adaptor.getScalesB()),
         /*scales idx A=*/scalesIdxA,
         /*scales idx B=*/scalesIdxB});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
Expand Down Expand Up @@ -1474,8 +1528,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering>(converter, chipset);
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
13 changes: 13 additions & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,19 @@ LogicalResult GatherToLDSOp::verify() {
return success();
}

/// Verify that each scale opsel index selects one of the four packed scale
/// bytes (valid range [0, 3]).
LogicalResult ScaledMFMAOp::verify() {
  constexpr unsigned kMaxScaleIdx = 3;

  if (getScalesIdxA() > kMaxScaleIdx)
    return emitOpError("scales idx A must be a value from 0 to 3 inclusive");

  if (getScalesIdxB() > kMaxScaleIdx)
    return emitOpError("scales idx B must be a value from 0 to 3 inclusive");

  return success();
}

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,51 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,

func.return
}

// Lowering of amdgpu.scaled_mfma for every supported element type
// (fp8E4M3FN, fp8E5M2, fp6E2M3FN, fp6E3M2FN, fp4E2M1FN) at both tile sizes
// (M=N=32/K=64 and M=N=16/K=128). Scales come either packed in a
// vector<4xi8> (%arg7) or as a single i8 (%arg8); the [0]/[1] selectors
// appear as the trailing %[[c0]]/%[[c1]] intrinsic operands.
func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
%arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
%arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
%arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
%arg7 : vector<4xi8>, %arg8 : i8) {

// CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.bitcast

// fp8 operands: 32 x 8-bit elements lower to vector<8xi32>.
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * ( %arg8 [ 1 ] * %arg2 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf8E4M3FN>, vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * ( %arg8 [ 1 ] * %arg2 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf8E4M3FN>, vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<4xf32>

// CHECK: llvm.bitcast

// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * ( %arg8 [ 1 ] * %arg3 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf8E5M2>, vector<4xi8>, vector<32xf8E5M2>, i8, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * ( %arg8 [ 1 ] * %arg3 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf8E5M2>, vector<4xi8>, vector<32xf8E5M2>, i8, vector<4xf32>

// CHECK: llvm.bitcast

// fp6 operands: 32 x 6-bit elements lower to vector<6xi32>.
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * ( %arg8 [ 1 ] * %arg4 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf6E2M3FN>, vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * ( %arg8 [ 1 ] * %arg4 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf6E2M3FN>, vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<4xf32>

// CHECK: llvm.bitcast
// CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32

// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * ( %arg8 [ 1 ] * %arg5 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf6E3M2FN>, vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * ( %arg8 [ 1 ] * %arg5 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf6E3M2FN>, vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<4xf32>

// CHECK: llvm.bitcast
// CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32

// fp4 operands: 32 x 4-bit elements lower to vector<4xi32>.
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i8, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * ( %arg8 [ 1 ] * %arg6 ) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<32xf4E2M1FN>, vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i8, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * ( %arg8 [ 1 ] * %arg6 ) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<32xf4E2M1FN>, vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<4xf32>

func.return
}
Loading