Skip to content

Commit b1262d1

Browse files
[mlir][ROCDL] Refactor wmma intrinsics to use attributes not operands where possible (#167041)
The current implementation of the WMMA intrinsic ops as they are defined in the ROCDL tablegen is incorrect. They represent as operands what should be attributes such as `clamp`, `opsel`, `signA/signB`. This change performs a refactoring to bring it in line with what we expect. --------- Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 630dfc9 commit b1262d1

File tree

7 files changed

+293
-186
lines changed

7 files changed

+293
-186
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 137 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -596,51 +596,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
596596

597597
//===---------------------------------------------------------------------===//
598598
// WMMA intrinsics
599-
class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
600-
list<Trait> traits = []> :
601-
ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
602-
Arguments<(ins Variadic<LLVM_Type>:$args)> {
603-
let assemblyFormat =
604-
"$args attr-dict `:` functional-type($args, $res)";
599+
class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
600+
[0], [0], [], 1, 0, 0, 0, [], []>,
601+
Arguments<(ins
602+
LLVM_ScalarOrVectorOf<AB>:$A,
603+
LLVM_ScalarOrVectorOf<AB>:$B,
604+
LLVM_ScalarOrVectorOf<CD>:$C)> {
605+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
606+
let assemblyFormat = [{
607+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
608+
}];
609+
}
610+
611+
class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
612+
[0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
613+
Arguments<(ins
614+
LLVM_ScalarOrVectorOf<AB>:$A,
615+
LLVM_ScalarOrVectorOf<AB>:$B,
616+
LLVM_ScalarOrVectorOf<CD>:$C,
617+
DefaultValuedAttr<I1Attr, "0">:$opsel)> {
618+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
619+
let assemblyFormat = [{
620+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
621+
}];
622+
}
623+
624+
class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
625+
[0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
626+
Arguments<(ins
627+
DefaultValuedAttr<I1Attr, "0">:$signA,
628+
LLVM_ScalarOrVectorOf<AB>:$A,
629+
DefaultValuedAttr<I1Attr, "0">:$signB,
630+
LLVM_ScalarOrVectorOf<AB>:$B,
631+
LLVM_ScalarOrVectorOf<CD>:$C,
632+
DefaultValuedAttr<I1Attr, "0">:$clamp)> {
633+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
634+
let assemblyFormat = [{
635+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
636+
}];
637+
}
638+
639+
class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
640+
[0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
641+
Arguments<(ins
642+
DefaultValuedAttr<I1Attr, "0">:$signA,
643+
LLVM_ScalarOrVectorOf<AB>:$A,
644+
DefaultValuedAttr<I1Attr, "0">:$signB,
645+
LLVM_ScalarOrVectorOf<AB>:$B,
646+
DefaultValuedAttr<I16Attr, "0">:$modC,
647+
LLVM_ScalarOrVectorOf<CD>:$C,
648+
DefaultValuedAttr<I1Attr, "0">:$reuseA,
649+
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
650+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
651+
let assemblyFormat = [{
652+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
653+
}];
654+
}
655+
656+
class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
657+
[0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
658+
Arguments<(ins
659+
LLVM_ScalarOrVectorOf<AB>:$A,
660+
LLVM_ScalarOrVectorOf<AB>:$B,
661+
DefaultValuedAttr<I16Attr, "0">:$modC,
662+
LLVM_ScalarOrVectorOf<CD>:$C,
663+
DefaultValuedAttr<I1Attr, "0">:$reuseA,
664+
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
665+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
666+
let assemblyFormat = [{
667+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
668+
}];
669+
}
670+
671+
class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
672+
[0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
673+
Arguments<(ins
674+
DefaultValuedAttr<I1Attr, "0">:$signA,
675+
LLVM_ScalarOrVectorOf<AB>:$A,
676+
DefaultValuedAttr<I1Attr, "0">:$signB,
677+
LLVM_ScalarOrVectorOf<AB>:$B,
678+
DefaultValuedAttr<I16Attr, "0">:$modC,
679+
LLVM_ScalarOrVectorOf<C>:$C,
680+
DefaultValuedAttr<I1Attr, "0">:$reuseA,
681+
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
682+
let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
683+
let assemblyFormat = [{
684+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
685+
}];
686+
}
687+
688+
class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
689+
[0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
690+
Arguments<(ins
691+
DefaultValuedAttr<I1Attr, "0">:$signA,
692+
LLVM_ScalarOrVectorOf<AB>:$A,
693+
DefaultValuedAttr<I1Attr, "0">:$signB,
694+
LLVM_ScalarOrVectorOf<AB>:$B,
695+
LLVM_ScalarOrVectorOf<CD>:$C,
696+
DefaultValuedAttr<I1Attr, "0">:$reuseA,
697+
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
698+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
699+
let assemblyFormat = [{
700+
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
701+
}];
605702
}
606703

607704
// Available from gfx11
608-
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
609-
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
610-
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
611-
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
612-
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
613-
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
705+
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>;
706+
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>;
707+
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.f16.16x16x16.f16", F16, F16>;
708+
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger>;
709+
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger>;
710+
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger>;
614711
// Available from gfx12
615-
def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
616-
def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
617-
def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
618-
def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
619-
def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
712+
def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32>;
713+
def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32>;
714+
def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32>;
715+
def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32>;
716+
def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger>;
620717
// Available from gfx1250
621-
def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
622-
def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
623-
def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
624-
def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
625-
def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
626-
def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
627-
def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
628-
def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
629-
def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
630-
def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
631-
def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
632-
def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
633-
def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
634-
def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
635-
def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
636-
def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
637-
def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
638-
def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
639-
def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
640-
def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
641-
def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
642-
def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
643-
def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
718+
def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x4.f32", F32, F32>;
719+
def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.bf16", BF16, F32>;
720+
def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.f16", F16, F32>;
721+
def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f16.16x16x32.f16", F16, F16>;
722+
def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.bf16.16x16x32.bf16", BF16, BF16>;
723+
def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp<"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16>;
724+
def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32>;
725+
def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32>;
726+
def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32>;
727+
def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32>;
728+
def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16>;
729+
def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16>;
730+
def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16>;
731+
def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16>;
732+
def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32>;
733+
def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32>;
734+
def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32>;
735+
def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32>;
736+
def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16>;
737+
def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16>;
738+
def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16>;
739+
def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>;
740+
def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>;
644741

645742
//===---------------------------------------------------------------------===//
646743
// LDS transpose intrinsics (available in GFX950)

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1717
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1818
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
19+
#include "mlir/IR/Attributes.h"
1920
#include "mlir/IR/BuiltinAttributes.h"
2021
#include "mlir/IR/BuiltinTypes.h"
2122
#include "mlir/IR/TypeUtilities.h"
@@ -79,12 +80,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter,
7980
return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
8081
}
8182

82-
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
83-
bool value) {
84-
Type llvmI1 = rewriter.getI1Type();
85-
return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
86-
}
87-
8883
/// Returns the linear index used to access an element in the memref.
8984
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
9085
Location loc, MemRefDescriptor &memRefDescriptor,
@@ -684,23 +679,18 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
684679
/// intrinsics having been defined before the AMD backend supported bfloat. We
685680
/// similarly need to pack 8-bit float types into integers as if they were i8
686681
/// (which they are for the backend's purposes).
687-
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
688-
Location loc,
689-
const TypeConverter *typeConverter,
690-
bool isUnsigned, Value llvmInput,
691-
Value mlirInput,
692-
SmallVector<Value, 4> &operands) {
682+
static void wmmaPushInputOperand(
683+
ConversionPatternRewriter &rewriter, Location loc,
684+
const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
685+
Value mlirInput, SmallVectorImpl<Value> &operands,
686+
SmallVectorImpl<NamedAttribute> &attrs, StringRef attrName) {
693687
Type inputType = llvmInput.getType();
694688
auto vectorType = dyn_cast<VectorType>(inputType);
695689
if (!vectorType) {
696690
operands.push_back(llvmInput);
697691
return;
698692
}
699693
Type elemType = vectorType.getElementType();
700-
701-
if (elemType.isBF16())
702-
llvmInput = LLVM::BitcastOp::create(
703-
rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
704694
if (elemType.getIntOrFloatBitWidth() > 8) {
705695
operands.push_back(llvmInput);
706696
return;
@@ -719,8 +709,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
719709
} else if (elemType.isSignedInteger()) {
720710
localIsUnsigned = false;
721711
}
722-
Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
723-
operands.push_back(sign);
712+
attrs.push_back(
713+
NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
724714
}
725715

726716
int64_t numBits =
@@ -751,18 +741,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
751741
Location loc,
752742
const TypeConverter *typeConverter,
753743
Value output, int32_t subwordOffset,
754-
bool clamp, SmallVector<Value, 4> &operands) {
744+
bool clamp, SmallVectorImpl<Value> &operands,
745+
SmallVectorImpl<NamedAttribute> &attrs) {
755746
Type inputType = output.getType();
756747
auto vectorType = dyn_cast<VectorType>(inputType);
757748
Type elemType = vectorType.getElementType();
758-
if (elemType.isBF16())
759-
output = LLVM::BitcastOp::create(
760-
rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
761749
operands.push_back(output);
762750
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
763-
operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
751+
attrs.push_back(
752+
NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
764753
} else if (elemType.isInteger(32)) {
765-
operands.push_back(createI1Constant(rewriter, loc, clamp));
754+
attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
766755
}
767756
}
768757

@@ -1311,11 +1300,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
13111300
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
13121301
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
13131302

1314-
// The WMMA operations represent vectors of bf16s as vectors of i16s, so we
1315-
// need to bitcast bfloats to i16 and then bitcast them back.
1303+
bool isGFX1250 = chipset >= Chipset(12, 5, 0);
1304+
1305+
// The WMMA operations represent vectors of bf16s as vectors of i16s
1306+
// (except on gfx1250), so we need to bitcast bfloats to i16 and then
1307+
// bitcast them back.
1308+
auto aType = cast<VectorType>(adaptor.getSourceA().getType());
1309+
auto bType = cast<VectorType>(adaptor.getSourceB().getType());
1310+
auto destCType = cast<VectorType>(adaptor.getDestC().getType());
1311+
bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1312+
bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1313+
bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1314+
bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
13161315
VectorType rawOutType = outType;
1317-
if (outType.getElementType().isBF16())
1316+
if (castOutToI16)
13181317
rawOutType = outType.clone(rewriter.getI16Type());
1318+
Value a = adaptor.getSourceA();
1319+
if (castAToI16)
1320+
a = LLVM::BitcastOp::create(rewriter, loc,
1321+
aType.clone(rewriter.getI16Type()), a);
1322+
Value b = adaptor.getSourceB();
1323+
if (castBToI16)
1324+
b = LLVM::BitcastOp::create(rewriter, loc,
1325+
bType.clone(rewriter.getI16Type()), b);
1326+
Value destC = adaptor.getDestC();
1327+
if (castDestCToI16)
1328+
destC = LLVM::BitcastOp::create(
1329+
rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
13191330

13201331
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
13211332

@@ -1325,18 +1336,20 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
13251336
if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
13261337
return op.emitOpError("subwordOffset not supported on gfx12+");
13271338

1328-
OperationState loweredOp(loc, *maybeIntrinsic);
1329-
loweredOp.addTypes(rawOutType);
1330-
13311339
SmallVector<Value, 4> operands;
1332-
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
1333-
adaptor.getSourceA(), op.getSourceA(), operands);
1334-
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
1335-
adaptor.getSourceB(), op.getSourceB(), operands);
1336-
wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
1337-
op.getSubwordOffset(), op.getClamp(), operands);
1340+
SmallVector<NamedAttribute, 4> attrs;
1341+
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
1342+
op.getSourceA(), operands, attrs, "signA");
1343+
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
1344+
op.getSourceB(), operands, attrs, "signB");
1345+
wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
1346+
op.getSubwordOffset(), op.getClamp(), operands,
1347+
attrs);
13381348

1349+
OperationState loweredOp(loc, *maybeIntrinsic);
1350+
loweredOp.addTypes(rawOutType);
13391351
loweredOp.addOperands(operands);
1352+
loweredOp.addAttributes(attrs);
13401353
Operation *lowered = rewriter.create(loweredOp);
13411354

13421355
Operation *maybeCastBack = lowered;

0 commit comments

Comments
 (0)