Merged
54 changes: 53 additions & 1 deletion mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -687,6 +687,10 @@ def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [F32]>,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
// scaled_mfma
def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
// wmma
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
[4, 8, 16],
@@ -804,7 +808,7 @@ def AMDGPU_GatherToLDSOp :
TypeAttr:$transferType
)>,
Results<(outs)> {
let summary = "MLIR wrapper for CDNA mfma instructions";
let summary = "MLIR wrapper for CDNA Gather to LDS instructions";
let description = [{
The `amdgpu.gather_to_lds` op is a wrapper around the `global_load_lds` instructions.

@@ -830,4 +834,52 @@ def AMDGPU_GatherToLDSOp :
let hasVerifier = 1;
}

def AMDGPU_ScaledMFMAOp :
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
Arguments<(ins
I32Attr:$m,
I32Attr:$n,
I32Attr:$k,
ScaledMFMAInTypes:$sourceA,
ScaledMFMAInTypes:$sourceB,
ScaledMFMAOutTypes:$destC,
AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesA,
AnyTypeOf<[I8, FixedVectorOfLengthAndType<[4], [I8]>]>:$scalesB,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$scalesIdxA,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$scalesIdxB
)>,
Results<(outs ScaledMFMAOutTypes: $destD)> {
let summary = "MLIR wrapper for CDNA scaled mfma instructions";
let description = [{
The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
multiple outer products in order to allow fast matrix multiplication.

The wrapper will select an appropriate `mfma` instruction, if one is available,
based on the provided `m`, `k`, and `n` attributes, along with the types of the
source and destination arguments.

Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
intrinsics that take an integer type of width `32K` bits. For example,
one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
logically takes 4 i8s but whose intrinsics are specified to take an i32.
In these cases, the bytes in the vector will be concatenated in little-endian
order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8], and so on).

This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
- `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN), and
fp8 (f8E4M3FN and f8E5M2) types, using either M=N=16, K=128 or M=N=32, K=64 as the
tile size.
- `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
are omitted from this wrapper.
- The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
double-precision operations on gfx94x and so are not included here.
}];
let assemblyFormat = [{
`(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
attr-dict
`:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
}];
}
#endif // AMDGPU
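
For reference, a minimal usage sketch of the assembly format defined above (value names are hypothetical; the shapes mirror the gfx950 tests later in this diff). Note the `vector<4xi8>` scale operand, which the lowering packs into a single little-endian i32:

// A single 32x32x64 tile over fp8 inputs. Scale A is selected from byte 0
// of a packed vector<4xi8>; scale B is a plain i8 selected at index 1.
%d = amdgpu.scaled_mfma(%scalesA[0] * %a) * (%scalesB[1] * %b) + %acc
  { k = 64 : i32, m = 32 : i32, n = 32 : i32 }
  : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<16xf32>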
81 changes: 78 additions & 3 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -23,6 +23,7 @@

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <optional>

namespace mlir {
@@ -528,6 +529,25 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
return input;
}

/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is an i8 value, zero-extend it to i32
/// 2. If `input` is a vector of length 4 and type i8, bitcast it to i32
///
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value castScaledMFMAVectorOperand(ConversionPatternRewriter &rewriter,
Location loc, Value input) {
Type inputType = input.getType();
Type outputType = rewriter.getI32Type();
if (auto intType = dyn_cast<IntegerType>(inputType))
return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
@@ -833,6 +853,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
                               smfma.getSourceB().getType(),
                               smfma.getDestC().getType(), smfma.getM(),
                               smfma.getN(), smfma.getK(), 1u, chipset);
}

Review thread on this overload:

[Member] Rather than creating an overloaded function, you can pass the Operation* and then do the dyn_cast + if/else.

[Member] What's the benefit of branching on the operation type?

[Member] It can be templatized as well; my point was to remove the unnecessary function that does the same thing.

[Author @Muzammiluddin-Syed-ECE, Apr 28, 2025] IIUC, templating doesn't remove the need for branching, since I would still be branching off the class of the input op to pass in the appropriate arguments (they differ between the proposed scaled_mfma and the existing mfma). So, I opted for branching on the op type, which I find clearer.

[Contributor] Yeah, the reason this isn't a pure template<typename MfmaOp> is because the arguments don't match (MFMA has a getBlocks(), ScaledMFMA doesn't, for example).

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
@@ -954,6 +982,52 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}
};

struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern(converter), chipset(chipset) {}

Chipset chipset;

LogicalResult
matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());

if (chipset.majorVersion != 9 || chipset < kGfx950)
return op->emitOpError("scaled MFMA only supported on gfx950+");
std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
if (!maybeScaledIntrinsic.has_value())
return op.emitOpError(
"no intrinsic matching scaled MFMA size on given chipset");

auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
loweredOp.addOperands(
{convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
adaptor.getDestC()});
Value scalesIdxA =
createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
Value scalesIdxB =
createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
loweredOp.addOperands(
{createI32Constant(rewriter, loc, aTypeCode),
createI32Constant(rewriter, loc, bTypeCode),
/*scales idx A=*/scalesIdxA,
/*scales A*/
castScaledMFMAVectorOperand(rewriter, loc, adaptor.getScalesA()),
/*scales idx B=*/scalesIdxB,
/*scales B*/
castScaledMFMAVectorOperand(rewriter, loc, adaptor.getScalesB())});
Value lowered = rewriter.create(loweredOp)->getResult(0);
rewriter.replaceOp(op, lowered);
return success();
}
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -1474,8 +1548,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering>(converter, chipset);
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}
49 changes: 49 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -51,3 +51,52 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,

func.return
}

func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
%arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
%arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
%arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
%arg7 : vector<4xi8>, %arg8 : i8) {

// CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[c2:.+]] = llvm.bitcast{{.*}} : vector<4xi8> to i32
// CHECK: %[[c3:.+]] = llvm.zext{{.*}} : i8 to i32

// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf8E4M3FN>, i8, vector<32xf8E4M3FN>, vector<4xf32>

// CHECK: llvm.bitcast

// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf8E5M2>, i8, vector<32xf8E5M2>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf8E5M2>, i8, vector<32xf8E5M2>, vector<4xf32>

// CHECK: llvm.bitcast

// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<4xf32>

// CHECK: llvm.bitcast
// CHECK: llvm.mlir.constant(3 : i32) : i32

// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<32xf6E3M2FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf6E3M2FN>, i8, vector<32xf6E3M2FN>, vector<4xf32>

// CHECK: llvm.bitcast
// CHECK: llvm.mlir.constant(4 : i32) : i32

// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<32xf4E2M1FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xi8>, vector<32xf4E2M1FN>, i8, vector<32xf4E2M1FN>, vector<4xf32>

func.return
}
7 changes: 7 additions & 0 deletions mlir/test/Dialect/AMDGPU/ops.mlir
@@ -164,3 +164,10 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
func.return %0 : f32
}

// CHECK-LABEL: func @scaled_mfma
func.func @scaled_mfma(%arg0 : i8, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
// CHECK: amdgpu.scaled_mfma
%0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : i8, vector<32xf6E2M3FN>, i8, vector<32xf6E2M3FN>, vector<16xf32>
func.return %0 : vector<16xf32>
}