Skip to content

Commit fc0d7fb

Browse files
[mlir][ROCDL] refactor wmma intrinsics to use attributes instead of operands where possible
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 913849e commit fc0d7fb

File tree

6 files changed

+286
-179
lines changed

6 files changed

+286
-179
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 137 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -582,51 +582,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
582582

583583
//===---------------------------------------------------------------------===//
584584
// WMMA intrinsics
585-
class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
586-
list<Trait> traits = []> :
587-
ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
588-
Arguments<(ins Variadic<LLVM_Type>:$args)> {
589-
let assemblyFormat =
590-
"$args attr-dict `:` functional-type($args, $res)";
585+
// Plain WMMA intrinsic op: three SSA operands and no attribute arguments.
//   $A, $B : scalar-or-vector of element type `AB`.
//   $C     : scalar-or-vector of element type `CD` (accumulator input).
//   $res   : scalar-or-vector of element type `CD`.
// Both trailing lists handed to ROCDL_IntrOp are empty, i.e. this variant
// declares no attribute positions or attribute names.
class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
586+
[0], [0], [], 1, 0, 0, 0, [], []>,
587+
Arguments<(ins
588+
LLVM_ScalarOrVectorOf<AB>:$A,
589+
LLVM_ScalarOrVectorOf<AB>:$B,
590+
LLVM_ScalarOrVectorOf<CD>:$C)> {
591+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
592+
let assemblyFormat = [{
593+
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
594+
}];
595+
}
596+
597+
// WMMA intrinsic op with a trailing `opsel` flag carried as an attribute.
//   $A, $B : scalar-or-vector of element type `AB`.
//   $C     : scalar-or-vector of element type `CD`.
//   $opsel : I1 attribute, defaults to 0; listed at position 3 of the `ins`
//            list, which matches the `[3]` / ["opsel"] pair passed to
//            ROCDL_IntrOp.
// The custom assembly only spells out $A, $B and $C; `opsel` is printed and
// parsed through attr-dict.
class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
598+
[0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
599+
Arguments<(ins
600+
LLVM_ScalarOrVectorOf<AB>:$A,
601+
LLVM_ScalarOrVectorOf<AB>:$B,
602+
LLVM_ScalarOrVectorOf<CD>:$C,
603+
DefaultValuedAttr<I1Attr, "0">:$opsel)> {
604+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
605+
let assemblyFormat = [{
606+
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
607+
}];
608+
}
609+
610+
// WMMA intrinsic op with per-input sign flags and a clamp flag, all carried
// as attributes.
//   $signA, $signB : I1 attributes (default 0) placed immediately before $A
//                    and $B in the `ins` list (positions 0 and 2).
//   $A, $B         : scalar-or-vector of element type `AB`.
//   $C             : scalar-or-vector of element type `CD`.
//   $clamp         : I1 attribute (default 0) at position 5.
// The position list [0, 2, 5] passed to ROCDL_IntrOp lines up with where
// signA / signB / clamp sit in the `ins` list. Only $A, $B and $C appear in
// the custom assembly; the attributes go through attr-dict.
class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
611+
[0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
612+
Arguments<(ins
613+
DefaultValuedAttr<I1Attr, "0">:$signA,
614+
LLVM_ScalarOrVectorOf<AB>:$A,
615+
DefaultValuedAttr<I1Attr, "0">:$signB,
616+
LLVM_ScalarOrVectorOf<AB>:$B,
617+
LLVM_ScalarOrVectorOf<CD>:$C,
618+
DefaultValuedAttr<I1Attr, "0">:$clamp)> {
619+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
620+
let assemblyFormat = [{
621+
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
622+
}];
623+
}
624+
625+
// WMMA intrinsic op with modifiers on A, B and C plus reuse flags, all
// carried as attributes.
//   $signA, $signB   : I1 attributes (default 0) at positions 0 and 2.
//   $A, $B           : scalar-or-vector of element type `AB`.
//   $modC            : I16 attribute (default 0) at position 4, placed
//                      immediately before $C.
//   $C               : scalar-or-vector of element type `CD`.
//   $reuseA, $reuseB : I1 attributes (default 0) at positions 6 and 7.
// The position list [0, 2, 4, 6, 7] passed to ROCDL_IntrOp matches the
// attribute slots in the `ins` list. The result has the same element type
// (`CD`) as $C. Attributes are only reachable through attr-dict in the
// custom assembly.
class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
626+
[0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
627+
Arguments<(ins
628+
DefaultValuedAttr<I1Attr, "0">:$signA,
629+
LLVM_ScalarOrVectorOf<AB>:$A,
630+
DefaultValuedAttr<I1Attr, "0">:$signB,
631+
LLVM_ScalarOrVectorOf<AB>:$B,
632+
DefaultValuedAttr<I16Attr, "0">:$modC,
633+
LLVM_ScalarOrVectorOf<CD>:$C,
634+
DefaultValuedAttr<I1Attr, "0">:$reuseA,
635+
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
636+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
637+
let assemblyFormat = [{
638+
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
639+
}];
640+
}
641+
642+
// WMMA intrinsic op with a modifier on C and reuse flags, but no sign flags
// on A / B.
//   $A, $B           : scalar-or-vector of element type `AB` (positions 0, 1).
//   $modC            : I16 attribute (default 0) at position 2.
//   $C               : scalar-or-vector of element type `CD`.
//   $reuseA, $reuseB : I1 attributes (default 0) at positions 4 and 5.
// The position list [2, 4, 5] passed to ROCDL_IntrOp matches the attribute
// slots in the `ins` list. Attributes are only reachable through attr-dict
// in the custom assembly.
class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
643+
[0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
644+
Arguments<(ins
645+
LLVM_ScalarOrVectorOf<AB>:$A,
646+
LLVM_ScalarOrVectorOf<AB>:$B,
647+
DefaultValuedAttr<I16Attr, "0">:$modC,
648+
LLVM_ScalarOrVectorOf<CD>:$C,
649+
DefaultValuedAttr<I1Attr, "0">:$reuseA,
650+
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
651+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
652+
let assemblyFormat = [{
653+
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
654+
}];
655+
}
656+
657+
// Same argument layout as the "ModsAll_Reuse" variant, except that the
// accumulator input and the result may have different element types:
//   $signA, $signB   : I1 attributes (default 0) at positions 0 and 2.
//   $A, $B           : scalar-or-vector of element type `AB`.
//   $modC            : I16 attribute (default 0) at position 4.
//   $C               : scalar-or-vector of element type `C`.
//   $reuseA, $reuseB : I1 attributes (default 0) at positions 6 and 7.
//   $res             : scalar-or-vector of element type `D`.
// Two argument indices ([1, 5], i.e. the slots of $A and $C in the `ins`
// list) are passed in the second list to ROCDL_IntrOp, where the other
// variants pass a single index.
class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
658+
[0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
659+
Arguments<(ins
660+
DefaultValuedAttr<I1Attr, "0">:$signA,
661+
LLVM_ScalarOrVectorOf<AB>:$A,
662+
DefaultValuedAttr<I1Attr, "0">:$signB,
663+
LLVM_ScalarOrVectorOf<AB>:$B,
664+
DefaultValuedAttr<I16Attr, "0">:$modC,
665+
LLVM_ScalarOrVectorOf<C>:$C,
666+
DefaultValuedAttr<I1Attr, "0">:$reuseA,
667+
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
668+
let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
669+
let assemblyFormat = [{
670+
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
671+
}];
672+
}
673+
674+
// WMMA intrinsic op with sign flags on A / B and reuse flags, but no
// modifier on C and no clamp.
//   $signA, $signB   : I1 attributes (default 0) at positions 0 and 2.
//   $A, $B           : scalar-or-vector of element type `AB`.
//   $C               : scalar-or-vector of element type `CD` (position 4).
//   $reuseA, $reuseB : I1 attributes (default 0) at positions 5 and 6.
// The position list [0, 2, 5, 6] passed to ROCDL_IntrOp matches the
// attribute slots in the `ins` list. Attributes are only reachable through
// attr-dict in the custom assembly.
class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
675+
[0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
676+
Arguments<(ins
677+
DefaultValuedAttr<I1Attr, "0">:$signA,
678+
LLVM_ScalarOrVectorOf<AB>:$A,
679+
DefaultValuedAttr<I1Attr, "0">:$signB,
680+
LLVM_ScalarOrVectorOf<AB>:$B,
681+
LLVM_ScalarOrVectorOf<CD>:$C,
682+
DefaultValuedAttr<I1Attr, "0">:$reuseA,
683+
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
684+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
685+
let assemblyFormat = [{
686+
$A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
687+
}];
591688
}
592689

593690
// Available from gfx11
594-
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
595-
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
596-
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
597-
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
598-
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
599-
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
691+
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>;
692+
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>;
693+
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.f16.16x16x16.f16", F16, F16>;
694+
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger>;
695+
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger>;
696+
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger>;
600697
// Available from gfx12
601-
def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
602-
def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
603-
def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
604-
def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
605-
def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
698+
def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32>;
699+
def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32>;
700+
def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32>;
701+
def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32>;
702+
def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger>;
606703
// Available from gfx1250
607-
def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
608-
def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
609-
def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
610-
def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
611-
def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
612-
def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
613-
def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
614-
def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
615-
def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
616-
def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
617-
def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
618-
def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
619-
def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
620-
def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
621-
def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
622-
def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
623-
def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
624-
def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
625-
def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
626-
def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
627-
def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
628-
def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
629-
def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
704+
def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x4.f32", F32, F32>;
705+
def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.bf16", BF16, F32>;
706+
def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.f16", F16, F32>;
707+
def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f16.16x16x32.f16", F16, F16>;
708+
def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.bf16.16x16x32.bf16", BF16, BF16>;
709+
def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp<"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16>;
710+
def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32>;
711+
def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32>;
712+
def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32>;
713+
def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32>;
714+
def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16>;
715+
def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16>;
716+
def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16>;
717+
def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16>;
718+
def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32>;
719+
def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32>;
720+
def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32>;
721+
def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32>;
722+
def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16>;
723+
def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16>;
724+
def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16>;
725+
def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>;
726+
def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>;
630727

631728
//===---------------------------------------------------------------------===//
632729
// LDS transpose intrinsics (available in GFX950)

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1717
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1818
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
19+
#include "mlir/IR/Attributes.h"
1920
#include "mlir/IR/BuiltinAttributes.h"
2021
#include "mlir/IR/BuiltinTypes.h"
2122
#include "mlir/IR/TypeUtilities.h"
@@ -79,12 +80,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter,
7980
return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
8081
}
8182

82-
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
83-
bool value) {
84-
Type llvmI1 = rewriter.getI1Type();
85-
return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
86-
}
87-
8883
/// Returns the linear index used to access an element in the memref.
8984
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
9085
Location loc, MemRefDescriptor &memRefDescriptor,
@@ -684,23 +679,18 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
684679
/// intrinsics having been defined before the AMD backend supported bfloat. We
685680
/// similarly need to pack 8-bit float types into integers as if they were i8
686681
/// (which they are for the backend's purposes).
687-
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
688-
Location loc,
689-
const TypeConverter *typeConverter,
690-
bool isUnsigned, Value llvmInput,
691-
Value mlirInput,
692-
SmallVector<Value, 4> &operands) {
682+
static void wmmaPushInputOperand(
683+
ConversionPatternRewriter &rewriter, Location loc,
684+
const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
685+
Value mlirInput, SmallVector<Value, 4> &operands,
686+
SmallVector<NamedAttribute, 4> &attrs, StringRef attrName) {
693687
Type inputType = llvmInput.getType();
694688
auto vectorType = dyn_cast<VectorType>(inputType);
695689
if (!vectorType) {
696690
operands.push_back(llvmInput);
697691
return;
698692
}
699693
Type elemType = vectorType.getElementType();
700-
701-
if (elemType.isBF16())
702-
llvmInput = LLVM::BitcastOp::create(
703-
rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
704694
if (elemType.getIntOrFloatBitWidth() > 8) {
705695
operands.push_back(llvmInput);
706696
return;
@@ -719,8 +709,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
719709
} else if (elemType.isSignedInteger()) {
720710
localIsUnsigned = false;
721711
}
722-
Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
723-
operands.push_back(sign);
712+
attrs.push_back(NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
724713
}
725714

726715
int64_t numBits =
@@ -751,18 +740,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
751740
Location loc,
752741
const TypeConverter *typeConverter,
753742
Value output, int32_t subwordOffset,
754-
bool clamp, SmallVector<Value, 4> &operands) {
743+
bool clamp, SmallVector<Value, 4> &operands,
744+
SmallVector<NamedAttribute, 4> &attrs) {
755745
Type inputType = output.getType();
756746
auto vectorType = dyn_cast<VectorType>(inputType);
757747
Type elemType = vectorType.getElementType();
758-
if (elemType.isBF16())
759-
output = LLVM::BitcastOp::create(
760-
rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
761748
operands.push_back(output);
762749
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
763-
operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
750+
attrs.push_back(
751+
NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
764752
} else if (elemType.isInteger(32)) {
765-
operands.push_back(createI1Constant(rewriter, loc, clamp));
753+
attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
766754
}
767755
}
768756

@@ -1302,6 +1290,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
13021290
LogicalResult
13031291
matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
13041292
ConversionPatternRewriter &rewriter) const override {
1293+
13051294
Location loc = op.getLoc();
13061295
auto outType =
13071296
typeConverter->convertType<VectorType>(op.getDestD().getType());
@@ -1311,34 +1300,56 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
13111300
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
13121301
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
13131302

1314-
// The WMMA operations represent vectors of bf16s as vectors of i16s, so we
1315-
// need to bitcast bfloats to i16 and then bitcast them back.
1303+
bool isGFX1250 = chipset == Chipset(12, 5, 0);
1304+
1305+
// The WMMA operations represent vectors of bf16s as vectors of i16s
1306+
// (except on gfx1250), so we need to bitcast bfloats to i16 and then
1307+
// bitcast them back.
1308+
auto aType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
1309+
// Converted type of source B; it decides whether B is bitcast from bf16 to
// i16 and supplies the vector shape used for that bitcast.
auto bType = dyn_cast<VectorType>(adaptor.getSourceB().getType());
1310+
auto destCType = dyn_cast<VectorType>(adaptor.getDestC().getType());
1311+
bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
1312+
bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
1313+
bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
1314+
bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
13161315
VectorType rawOutType = outType;
1317-
if (outType.getElementType().isBF16())
1316+
if (castOutToI16)
13181317
rawOutType = outType.clone(rewriter.getI16Type());
1318+
Value a = adaptor.getSourceA();
1319+
if (castAToI16)
1320+
a = LLVM::BitcastOp::create(rewriter, loc,
1321+
aType.clone(rewriter.getI16Type()), a);
1322+
Value b = adaptor.getSourceB();
1323+
if (castBToI16)
1324+
b = LLVM::BitcastOp::create(rewriter, loc,
1325+
bType.clone(rewriter.getI16Type()), b);
1326+
Value destC = adaptor.getDestC();
1327+
if (castDestCToI16)
1328+
destC = LLVM::BitcastOp::create(
1329+
rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
13191330

13201331
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
1321-
13221332
if (!maybeIntrinsic.has_value())
13231333
return op.emitOpError("no intrinsic matching WMMA on the given chipset");
13241334

13251335
if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
13261336
return op.emitOpError("subwordOffset not supported on gfx12+");
13271337

1328-
OperationState loweredOp(loc, *maybeIntrinsic);
1329-
loweredOp.addTypes(rawOutType);
1330-
13311338
SmallVector<Value, 4> operands;
1332-
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
1333-
adaptor.getSourceA(), op.getSourceA(), operands);
1334-
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
1335-
adaptor.getSourceB(), op.getSourceB(), operands);
1336-
wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
1337-
op.getSubwordOffset(), op.getClamp(), operands);
1339+
SmallVector<NamedAttribute, 4> attrs;
1340+
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
1341+
op.getSourceA(), operands, attrs, "signA");
1342+
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
1343+
op.getSourceB(), operands, attrs, "signB");
1344+
wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
1345+
op.getSubwordOffset(), op.getClamp(), operands,
1346+
attrs);
13381347

1348+
OperationState loweredOp(loc, *maybeIntrinsic);
1349+
loweredOp.addTypes(rawOutType);
13391350
loweredOp.addOperands(operands);
1351+
loweredOp.addAttributes(attrs);
13401352
Operation *lowered = rewriter.create(loweredOp);
1341-
13421353
Operation *maybeCastBack = lowered;
13431354
if (rawOutType != outType)
13441355
maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,

0 commit comments

Comments
 (0)