Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 137 additions & 40 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -582,51 +582,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f

//===---------------------------------------------------------------------===//
// WMMA intrinsics
class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
list<Trait> traits = []> :
ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
let assemblyFormat =
"$args attr-dict `:` functional-type($args, $res)";
// Plain WMMA intrinsic op: three value operands and no attribute arguments.
//   $A, $B : scalar or vector of `AB`
//   $C, $res : scalar or vector of `CD`
// NOTE(review): the trailing `[], []` passed to ROCDL_IntrOp are presumably
// the immediate-argument positions and their attribute names (empty here
// because this variant has none) - confirm against ROCDL_IntrOp.
class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
[0], [0], [], 1, 0, 0, 0, [], []>,
Arguments<(ins
LLVM_ScalarOrVectorOf<AB>:$A,
LLVM_ScalarOrVectorOf<AB>:$B,
LLVM_ScalarOrVectorOf<CD>:$C)> {
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
// Textual form: `$A, $B, $C {attrs} : (operand types) -> result type`.
let assemblyFormat = [{
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
}];
}

// WMMA intrinsic op with a trailing `opsel` I1 attribute (default 0).
//   $A, $B : scalar or vector of `AB`
//   $C, $res : scalar or vector of `CD`
// The `[3]` / `["opsel"]` pair handed to ROCDL_IntrOp matches the 0-based
// position and the name of $opsel in the argument list below
// (A=0, B=1, C=2, opsel=3); keep the three in sync when editing.
// NOTE(review): presumably these lists tell ROCDL_IntrOp which intrinsic
// arguments are materialized from attributes - confirm against its definition.
// The AMDGPU-to-ROCDL lowering attaches this attribute by the string "opsel",
// so renaming it here requires updating that lowering too.
class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
[0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
Arguments<(ins
LLVM_ScalarOrVectorOf<AB>:$A,
LLVM_ScalarOrVectorOf<AB>:$B,
LLVM_ScalarOrVectorOf<CD>:$C,
DefaultValuedAttr<I1Attr, "0">:$opsel)> {
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
// Only the value operands are named in the format; $opsel is not, which
// leaves attr-dict as the only place it can appear in the textual form.
let assemblyFormat = [{
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
}];
}

// WMMA intrinsic op with `signA`, `signB` and `clamp` I1 attributes
// (all default 0) interleaved with the value operands.
//   $A, $B : scalar or vector of `AB`
//   $C, $res : scalar or vector of `CD`
// The `[0, 2, 5]` list handed to ROCDL_IntrOp matches the 0-based positions of
// $signA, $signB and $clamp in the argument list below, and the string list
// repeats their names in the same order; keep all three in sync when editing.
// NOTE(review): presumably these lists tell ROCDL_IntrOp which intrinsic
// arguments are materialized from attributes - confirm against its definition.
// The AMDGPU-to-ROCDL lowering attaches "signA", "signB" and "clamp" by name,
// so renaming them here requires updating that lowering too.
class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
[0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
Arguments<(ins
DefaultValuedAttr<I1Attr, "0">:$signA,
LLVM_ScalarOrVectorOf<AB>:$A,
DefaultValuedAttr<I1Attr, "0">:$signB,
LLVM_ScalarOrVectorOf<AB>:$B,
LLVM_ScalarOrVectorOf<CD>:$C,
DefaultValuedAttr<I1Attr, "0">:$clamp)> {
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
// Only the value operands are named in the format; the attributes are not,
// which leaves attr-dict as the only place they can appear.
let assemblyFormat = [{
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
}];
}

// WMMA intrinsic op with modifier attributes on every input plus two reuse
// flags: `signA`, `signB`, `reuseA`, `reuseB` are I1 and `modC` is I16
// (all default 0).
//   $A, $B : scalar or vector of `AB`
//   $C, $res : scalar or vector of `CD`
// The `[0, 2, 4, 6, 7]` list handed to ROCDL_IntrOp matches the 0-based
// positions of $signA, $signB, $modC, $reuseA and $reuseB in the argument
// list below, and the string list repeats their names in the same order;
// keep all three in sync when editing.
// NOTE(review): presumably these lists tell ROCDL_IntrOp which intrinsic
// arguments are materialized from attributes - confirm against its definition.
class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
[0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
Arguments<(ins
DefaultValuedAttr<I1Attr, "0">:$signA,
LLVM_ScalarOrVectorOf<AB>:$A,
DefaultValuedAttr<I1Attr, "0">:$signB,
LLVM_ScalarOrVectorOf<AB>:$B,
DefaultValuedAttr<I16Attr, "0">:$modC,
LLVM_ScalarOrVectorOf<CD>:$C,
DefaultValuedAttr<I1Attr, "0">:$reuseA,
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
// Only the value operands are named in the format; the attributes are not,
// which leaves attr-dict as the only place they can appear.
let assemblyFormat = [{
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
}];
}

// WMMA intrinsic op with a modifier only on the accumulator plus two reuse
// flags: `modC` is I16, `reuseA` and `reuseB` are I1 (all default 0).
// Unlike the sign-carrying variants there are no attributes in front of
// $A or $B.
//   $A, $B : scalar or vector of `AB`
//   $C, $res : scalar or vector of `CD`
// The `[2, 4, 5]` list handed to ROCDL_IntrOp matches the 0-based positions of
// $modC, $reuseA and $reuseB in the argument list below, and the string list
// repeats their names in the same order; keep all three in sync when editing.
// NOTE(review): presumably these lists tell ROCDL_IntrOp which intrinsic
// arguments are materialized from attributes - confirm against its definition.
class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
[0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
Arguments<(ins
LLVM_ScalarOrVectorOf<AB>:$A,
LLVM_ScalarOrVectorOf<AB>:$B,
DefaultValuedAttr<I16Attr, "0">:$modC,
LLVM_ScalarOrVectorOf<CD>:$C,
DefaultValuedAttr<I1Attr, "0">:$reuseA,
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
// Only the value operands are named in the format; the attributes are not,
// which leaves attr-dict as the only place they can appear.
let assemblyFormat = [{
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
}];
}

// Same argument layout as the "ModsAll_Reuse" variant, but the accumulator
// input and the result have independent element types (`C` and `D`)
// instead of a shared one.
//   $A, $B : scalar or vector of `AB`
//   $C     : scalar or vector of `C`
//   $res   : scalar or vector of `D`
// Attributes: `signA`, `signB`, `reuseA`, `reuseB` are I1 and `modC` is I16
// (all default 0). The `[0, 2, 4, 6, 7]` list handed to ROCDL_IntrOp matches
// their 0-based positions in the argument list below, and the string list
// repeats their names in the same order; keep all three in sync when editing.
// NOTE(review): the `[1, 5]` list is the only two-entry overloaded-operand
// list among these WMMA classes; assuming its indices count the attribute
// arguments the same way the immediate positions do, 1 and 5 select $A and
// $C - confirm against ROCDL_IntrOp.
class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
[0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
Arguments<(ins
DefaultValuedAttr<I1Attr, "0">:$signA,
LLVM_ScalarOrVectorOf<AB>:$A,
DefaultValuedAttr<I1Attr, "0">:$signB,
LLVM_ScalarOrVectorOf<AB>:$B,
DefaultValuedAttr<I16Attr, "0">:$modC,
LLVM_ScalarOrVectorOf<C>:$C,
DefaultValuedAttr<I1Attr, "0">:$reuseA,
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
// Only the value operands are named in the format; the attributes are not,
// which leaves attr-dict as the only place they can appear.
let assemblyFormat = [{
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
}];
}

// WMMA intrinsic op with sign attributes on $A and $B plus two reuse flags,
// and no modifier on the accumulator: `signA`, `signB`, `reuseA`, `reuseB`
// are all I1 with default 0.
//   $A, $B : scalar or vector of `AB`
//   $C, $res : scalar or vector of `CD`
// The `[0, 2, 5, 6]` list handed to ROCDL_IntrOp matches the 0-based positions
// of $signA, $signB, $reuseA and $reuseB in the argument list below, and the
// string list repeats their names in the same order; keep all three in sync
// when editing.
// NOTE(review): presumably these lists tell ROCDL_IntrOp which intrinsic
// arguments are materialized from attributes - confirm against its definition.
class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
[0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
Arguments<(ins
DefaultValuedAttr<I1Attr, "0">:$signA,
LLVM_ScalarOrVectorOf<AB>:$A,
DefaultValuedAttr<I1Attr, "0">:$signB,
LLVM_ScalarOrVectorOf<AB>:$B,
LLVM_ScalarOrVectorOf<CD>:$C,
DefaultValuedAttr<I1Attr, "0">:$reuseA,
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
// Only the value operands are named in the format; the attributes are not,
// which leaves attr-dict as the only place they can appear.
let assemblyFormat = [{
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
}];
}

// Available from gfx11
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>;
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>;
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.f16.16x16x16.f16", F16, F16>;
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger>;
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger>;
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger>;
// Available from gfx12
def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32>;
def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32>;
def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32>;
def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32>;
def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger>;
// Available from gfx1250
def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x4.f32", F32, F32>;
def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.bf16", BF16, F32>;
def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.f16", F16, F32>;
def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f16.16x16x32.f16", F16, F16>;
def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.bf16.16x16x32.bf16", BF16, BF16>;
def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp<"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16>;
def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32>;
def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32>;
def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32>;
def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32>;
def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16>;
def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16>;
def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16>;
def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16>;
def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32>;
def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32>;
def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32>;
def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32>;
def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16>;
def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16>;
def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16>;
def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>;
def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>;

//===---------------------------------------------------------------------===//
// LDS transpose intrinsics (available in GFX950)
Expand Down
85 changes: 49 additions & 36 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
Expand Down Expand Up @@ -79,12 +80,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter,
return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}

/// Builds an LLVM::ConstantOp of the rewriter's i1 type holding `value`,
/// located at `loc`, and returns the op's result.
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI1Type(), value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
Location loc, MemRefDescriptor &memRefDescriptor,
Expand Down Expand Up @@ -684,23 +679,18 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
bool isUnsigned, Value llvmInput,
Value mlirInput,
SmallVector<Value, 4> &operands) {
static void wmmaPushInputOperand(
ConversionPatternRewriter &rewriter, Location loc,
const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
Value mlirInput, SmallVectorImpl<Value> &operands,
SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
Type inputType = llvmInput.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
if (!vectorType) {
operands.push_back(llvmInput);
return;
}
Type elemType = vectorType.getElementType();

if (elemType.isBF16())
llvmInput = LLVM::BitcastOp::create(
rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
Expand All @@ -719,8 +709,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
} else if (elemType.isSignedInteger()) {
localIsUnsigned = false;
}
Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
operands.push_back(sign);
attrs.push_back(
NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
}

int64_t numBits =
Expand Down Expand Up @@ -751,18 +741,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
Value output, int32_t subwordOffset,
bool clamp, SmallVector<Value, 4> &operands) {
bool clamp, SmallVectorImpl<Value> &operands,
SmallVectorImpl<NamedAttribute> &attrs) {
Type inputType = output.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
output = LLVM::BitcastOp::create(
rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
attrs.push_back(
NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
} else if (elemType.isInteger(32)) {
operands.push_back(createI1Constant(rewriter, loc, clamp));
attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
}
}

Expand Down Expand Up @@ -1311,11 +1300,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");

// The WMMA operations represent vectors of bf16s as vectors of i16s, so we
// need to bitcast bfloats to i16 and then bitcast them back.
bool isGFX1250 = chipset >= Chipset(12, 5, 0);

// The WMMA operations represent vectors of bf16s as vectors of i16s
// (except on gfx1250), so we need to bitcast bfloats to i16 and then
// bitcast them back.
auto aType = cast<VectorType>(adaptor.getSourceA().getType());
auto bType = cast<VectorType>(adaptor.getSourceB().getType());
auto destCType = cast<VectorType>(adaptor.getDestC().getType());
bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
VectorType rawOutType = outType;
if (outType.getElementType().isBF16())
if (castOutToI16)
rawOutType = outType.clone(rewriter.getI16Type());
Value a = adaptor.getSourceA();
if (castAToI16)
a = LLVM::BitcastOp::create(rewriter, loc,
aType.clone(rewriter.getI16Type()), a);
Value b = adaptor.getSourceB();
if (castBToI16)
b = LLVM::BitcastOp::create(rewriter, loc,
bType.clone(rewriter.getI16Type()), b);
Value destC = adaptor.getDestC();
if (castDestCToI16)
destC = LLVM::BitcastOp::create(
rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);

std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

Expand All @@ -1325,18 +1336,20 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
return op.emitOpError("subwordOffset not supported on gfx12+");

OperationState loweredOp(loc, *maybeIntrinsic);
loweredOp.addTypes(rawOutType);

SmallVector<Value, 4> operands;
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
adaptor.getSourceA(), op.getSourceA(), operands);
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
adaptor.getSourceB(), op.getSourceB(), operands);
wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
op.getSubwordOffset(), op.getClamp(), operands);
SmallVector<NamedAttribute, 4> attrs;
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
op.getSourceA(), operands, attrs, "signA");
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
op.getSourceB(), operands, attrs, "signB");
wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
op.getSubwordOffset(), op.getClamp(), operands,
attrs);

OperationState loweredOp(loc, *maybeIntrinsic);
loweredOp.addTypes(rawOutType);
loweredOp.addOperands(operands);
loweredOp.addAttributes(attrs);
Operation *lowered = rewriter.create(loweredOp);

Operation *maybeCastBack = lowered;
Expand Down
Loading
Loading