-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][ROCDL] Refactor wmma intrinsics to use attributes not operands where possible #167041
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ROCDL] Refactor wmma intrinsics to use attributes not operands where possible #167041
Conversation
|
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-llvm — Author: Muzammiluddin Syed (Muzammiluddin-Syed-ECE). Changes: The current implementation of the WMMA intrinsic ops, as defined in the ROCDL TableGen file, is incorrect: it represents as operands values that should be attributes (in the patch below: `signA`, `signB`, `clamp`, `opsel`, `modC`, `reuseA`, and `reuseB`). Patch is 59.09 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/167041.diff — 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 5241f9a6f2b43..3cc906030b1a4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -582,51 +582,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
//===---------------------------------------------------------------------===//
// WMMA intrinsics
-class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
- list<Trait> traits = []> :
- ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
- Arguments<(ins Variadic<LLVM_Type>:$args)> {
- let assemblyFormat =
- "$args attr-dict `:` functional-type($args, $res)";
+class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [0], [], 1, 0, 0, 0, [], []>,
+ Arguments<(ins
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ LLVM_ScalarOrVectorOf<CD>:$C)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
+ Arguments<(ins
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ LLVM_ScalarOrVectorOf<CD>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$opsel)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
+ Arguments<(ins
+ DefaultValuedAttr<I1Attr, "0">:$signA,
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ DefaultValuedAttr<I1Attr, "0">:$signB,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ LLVM_ScalarOrVectorOf<CD>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$clamp)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
+ Arguments<(ins
+ DefaultValuedAttr<I1Attr, "0">:$signA,
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ DefaultValuedAttr<I1Attr, "0">:$signB,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ DefaultValuedAttr<I16Attr, "0">:$modC,
+ LLVM_ScalarOrVectorOf<CD>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
+ Arguments<(ins
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ DefaultValuedAttr<I16Attr, "0">:$modC,
+ LLVM_ScalarOrVectorOf<CD>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
+ [0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
+ Arguments<(ins
+ DefaultValuedAttr<I1Attr, "0">:$signA,
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ DefaultValuedAttr<I1Attr, "0">:$signB,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ DefaultValuedAttr<I16Attr, "0">:$modC,
+ LLVM_ScalarOrVectorOf<C>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+ let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
+ Arguments<(ins
+ DefaultValuedAttr<I1Attr, "0">:$signA,
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ DefaultValuedAttr<I1Attr, "0">:$signB,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ LLVM_ScalarOrVectorOf<CD>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
}
// Available from gfx11
-def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
-def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
-def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
-def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
-def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
-def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>;
+def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.f16.16x16x16.f16", F16, F16>;
+def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger>;
+def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger>;
+def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger>;
// Available from gfx12
-def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
-def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
-def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger>;
// Available from gfx1250
-def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
-def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
-def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
-def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
-def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
-def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
-def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
-def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
+def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x4.f32", F32, F32>;
+def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.bf16", BF16, F32>;
+def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.f16", F16, F32>;
+def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f16.16x16x32.f16", F16, F16>;
+def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.bf16.16x16x32.bf16", BF16, BF16>;
+def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp<"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16>;
+def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>;
//===---------------------------------------------------------------------===//
// LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3a307a0756d93..996cf445a15d0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
@@ -79,12 +80,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter,
return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}
-static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
- bool value) {
- Type llvmI1 = rewriter.getI1Type();
- return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
-}
-
/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
Location loc, MemRefDescriptor &memRefDescriptor,
@@ -684,12 +679,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
-static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
- Location loc,
- const TypeConverter *typeConverter,
- bool isUnsigned, Value llvmInput,
- Value mlirInput,
- SmallVector<Value, 4> &operands) {
+static void wmmaPushInputOperand(
+ ConversionPatternRewriter &rewriter, Location loc,
+ const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
+ Value mlirInput, SmallVector<Value, 4> &operands,
+ SmallVector<NamedAttribute, 4> &attrs, StringRef attrName) {
Type inputType = llvmInput.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
if (!vectorType) {
@@ -697,10 +691,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
return;
}
Type elemType = vectorType.getElementType();
-
- if (elemType.isBF16())
- llvmInput = LLVM::BitcastOp::create(
- rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
@@ -719,8 +709,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
} else if (elemType.isSignedInteger()) {
localIsUnsigned = false;
}
- Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
- operands.push_back(sign);
+ attrs.push_back(NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
}
int64_t numBits =
@@ -751,18 +740,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
Value output, int32_t subwordOffset,
- bool clamp, SmallVector<Value, 4> &operands) {
+ bool clamp, SmallVector<Value, 4> &operands,
+ SmallVector<NamedAttribute, 4> &attrs) {
Type inputType = output.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
- if (elemType.isBF16())
- output = LLVM::BitcastOp::create(
- rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
- operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+ attrs.push_back(
+ NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
} else if (elemType.isInteger(32)) {
- operands.push_back(createI1Constant(rewriter, loc, clamp));
+ attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
}
}
@@ -1302,6 +1290,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
LogicalResult
matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+
Location loc = op.getLoc();
auto outType =
typeConverter->convertType<VectorType>(op.getDestD().getType());
@@ -1311,11 +1300,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
- // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
- // need to bitcast bfloats to i16 and then bitcast them back.
+ bool isGFX1250 = chipset == Chipset(12, 5, 0);
+
+ // The WMMA operations represent vectors of bf16s as vectors of i16s
+ // (except on gfx1250), so we need to bitcast bfloats to i16 and then
+ // bitcast them back.
+ auto aType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
+ auto bType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
+ auto destCType = dyn_cast<VectorType>(adaptor.getDestC().getType());
+ bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
+ bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
+ bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
+ bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
VectorType rawOutType = outType;
- if (outType.getElementType().isBF16())
+ if (castOutToI16)
rawOutType = outType.clone(rewriter.getI16Type());
+ Value a = adaptor.getSourceA();
+ if (castAToI16)
+ a = LLVM::BitcastOp::create(rewriter, loc,
+ aType.clone(rewriter.getI16Type()), a);
+ Value b = adaptor.getSourceB();
+ if (castBToI16)
+ b = LLVM::BitcastOp::create(rewriter, loc,
+ bType.clone(rewriter.getI16Type()), b);
+ Value destC = adaptor.getDestC();
+ if (castDestCToI16)
+ destC = LLVM::BitcastOp::create(
+ rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
@@ -1325,21 +1336,23 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
return op.emitOpError("subwordOffset not supported on gfx12+");
- OperationState loweredOp(loc, *maybeIntrinsic);
- loweredOp.addTypes(rawOutType);
-
SmallVector<Value, 4> operands;
- wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
- adaptor.getSourceA(), op.getSourceA(), operands);
- wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
- adaptor.getSourceB(), op.getSourceB(), operands);
- wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
- op.getSubwordOffset(), op.getClamp(), operands);
+ SmallVector<NamedAttribute, 4> attrs;
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
+ op.getSourceA(), operands, attrs, "signA");
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
+ op.getSourceB(), operands, attrs, "signB");
+ wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
+ op.getSubwordOffset(), op.getClamp(), operands,
+ attrs);
+ OperationState loweredOp(loc, *maybeIntrinsic);
+ loweredOp.addTypes(rawOutType);
loweredOp.addOpe...
[truncated]
|
|
@llvm/pr-subscribers-mlir — Author: Muzammiluddin Syed (Muzammiluddin-Syed-ECE). Changes: The current implementation of the WMMA intrinsic ops, as defined in the ROCDL TableGen file, is incorrect: it represents as operands values that should be attributes (in the patch below: `signA`, `signB`, `clamp`, `opsel`, `modC`, `reuseA`, and `reuseB`). Patch is 59.09 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/167041.diff — 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 5241f9a6f2b43..3cc906030b1a4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -582,51 +582,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
//===---------------------------------------------------------------------===//
// WMMA intrinsics
-class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
- list<Trait> traits = []> :
- ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
- Arguments<(ins Variadic<LLVM_Type>:$args)> {
- let assemblyFormat =
- "$args attr-dict `:` functional-type($args, $res)";
+class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [0], [], 1, 0, 0, 0, [], []>,
+ Arguments<(ins
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ LLVM_ScalarOrVectorOf<CD>:$C)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
+ Arguments<(ins
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ LLVM_ScalarOrVectorOf<CD>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$opsel)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
+ Arguments<(ins
+ DefaultValuedAttr<I1Attr, "0">:$signA,
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ DefaultValuedAttr<I1Attr, "0">:$signB,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ LLVM_ScalarOrVectorOf<CD>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$clamp)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
+ Arguments<(ins
+ DefaultValuedAttr<I1Attr, "0">:$signA,
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ DefaultValuedAttr<I1Attr, "0">:$signB,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ DefaultValuedAttr<I16Attr, "0">:$modC,
+ LLVM_ScalarOrVectorOf<CD>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
+ Arguments<(ins
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ DefaultValuedAttr<I16Attr, "0">:$modC,
+ LLVM_ScalarOrVectorOf<CD>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
+ [0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
+ Arguments<(ins
+ DefaultValuedAttr<I1Attr, "0">:$signA,
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ DefaultValuedAttr<I1Attr, "0">:$signB,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ DefaultValuedAttr<I16Attr, "0">:$modC,
+ LLVM_ScalarOrVectorOf<C>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+ let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
+}
+
+class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
+ [0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
+ Arguments<(ins
+ DefaultValuedAttr<I1Attr, "0">:$signA,
+ LLVM_ScalarOrVectorOf<AB>:$A,
+ DefaultValuedAttr<I1Attr, "0">:$signB,
+ LLVM_ScalarOrVectorOf<AB>:$B,
+ LLVM_ScalarOrVectorOf<CD>:$C,
+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
+ let assemblyFormat = [{
+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
+ }];
}
// Available from gfx11
-def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
-def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
-def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
-def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
-def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
-def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
+def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32>;
+def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf16", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.f16.16x16x16.f16", F16, F16>;
+def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp<"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger>;
+def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger>;
+def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger>;
// Available from gfx12
-def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
-def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_bf8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
-def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
-def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
+def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp<"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp<"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger>;
// Available from gfx1250
-def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
-def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
-def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
-def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
-def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
-def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
-def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
-def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
-def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
-def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
-def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
-def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
+def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x4.f32", F32, F32>;
+def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.bf16", BF16, F32>;
+def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f32.16x16x32.f16", F16, F32>;
+def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.f16.16x16x32.f16", F16, F16>;
+def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp<"wmma.bf16.16x16x32.bf16", BF16, BF16>;
+def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp<"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16>;
+def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32>;
+def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32>;
+def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16>;
+def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>;
+def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>;
//===---------------------------------------------------------------------===//
// LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3a307a0756d93..996cf445a15d0 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
@@ -79,12 +80,6 @@ static Value createI64Constant(ConversionPatternRewriter &rewriter,
return LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), value);
}
-static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
- bool value) {
- Type llvmI1 = rewriter.getI1Type();
- return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
-}
-
/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
Location loc, MemRefDescriptor &memRefDescriptor,
@@ -684,12 +679,11 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
/// intrinsics having been defined before the AMD backend supported bfloat. We
/// similarly need to pack 8-bit float types into integers as if they were i8
/// (which they are for the backend's purposes).
-static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
- Location loc,
- const TypeConverter *typeConverter,
- bool isUnsigned, Value llvmInput,
- Value mlirInput,
- SmallVector<Value, 4> &operands) {
+static void wmmaPushInputOperand(
+ ConversionPatternRewriter &rewriter, Location loc,
+ const TypeConverter *typeConverter, bool isUnsigned, Value llvmInput,
+ Value mlirInput, SmallVector<Value, 4> &operands,
+ SmallVector<NamedAttribute, 4> &attrs, StringRef attrName) {
Type inputType = llvmInput.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
if (!vectorType) {
@@ -697,10 +691,6 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
return;
}
Type elemType = vectorType.getElementType();
-
- if (elemType.isBF16())
- llvmInput = LLVM::BitcastOp::create(
- rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
@@ -719,8 +709,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
} else if (elemType.isSignedInteger()) {
localIsUnsigned = false;
}
- Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
- operands.push_back(sign);
+ attrs.push_back(NamedAttribute(attrName, rewriter.getBoolAttr(!localIsUnsigned)));
}
int64_t numBits =
@@ -751,18 +740,17 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
Value output, int32_t subwordOffset,
- bool clamp, SmallVector<Value, 4> &operands) {
+ bool clamp, SmallVector<Value, 4> &operands,
+ SmallVector<NamedAttribute, 4> &attrs) {
Type inputType = output.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
- if (elemType.isBF16())
- output = LLVM::BitcastOp::create(
- rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
- operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+ attrs.push_back(
+ NamedAttribute("opsel", rewriter.getBoolAttr(subwordOffset)));
} else if (elemType.isInteger(32)) {
- operands.push_back(createI1Constant(rewriter, loc, clamp));
+ attrs.push_back(NamedAttribute("clamp", rewriter.getBoolAttr(clamp)));
}
}
@@ -1302,6 +1290,7 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
LogicalResult
matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+
Location loc = op.getLoc();
auto outType =
typeConverter->convertType<VectorType>(op.getDestD().getType());
@@ -1311,11 +1300,33 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
- // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
- // need to bitcast bfloats to i16 and then bitcast them back.
+ bool isGFX1250 = chipset == Chipset(12, 5, 0);
+
+ // The WMMA operations represent vectors of bf16s as vectors of i16s
+ // (except on gfx1250), so we need to bitcast bfloats to i16 and then
+ // bitcast them back.
+ auto aType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
+ auto bType = dyn_cast<VectorType>(adaptor.getSourceA().getType());
+ auto destCType = dyn_cast<VectorType>(adaptor.getDestC().getType());
+ bool castAToI16 = aType.getElementType().isBF16() && !isGFX1250;
+ bool castBToI16 = bType.getElementType().isBF16() && !isGFX1250;
+ bool castDestCToI16 = destCType.getElementType().isBF16() && !isGFX1250;
+ bool castOutToI16 = outType.getElementType().isBF16() && !isGFX1250;
VectorType rawOutType = outType;
- if (outType.getElementType().isBF16())
+ if (castOutToI16)
rawOutType = outType.clone(rewriter.getI16Type());
+ Value a = adaptor.getSourceA();
+ if (castAToI16)
+ a = LLVM::BitcastOp::create(rewriter, loc,
+ aType.clone(rewriter.getI16Type()), a);
+ Value b = adaptor.getSourceB();
+ if (castBToI16)
+ b = LLVM::BitcastOp::create(rewriter, loc,
+ bType.clone(rewriter.getI16Type()), b);
+ Value destC = adaptor.getDestC();
+ if (castDestCToI16)
+ destC = LLVM::BitcastOp::create(
+ rewriter, loc, destCType.clone(rewriter.getI16Type()), destC);
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
@@ -1325,21 +1336,23 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
if (chipset.majorVersion >= 12 && op.getSubwordOffset() != 0)
return op.emitOpError("subwordOffset not supported on gfx12+");
- OperationState loweredOp(loc, *maybeIntrinsic);
- loweredOp.addTypes(rawOutType);
-
SmallVector<Value, 4> operands;
- wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
- adaptor.getSourceA(), op.getSourceA(), operands);
- wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
- adaptor.getSourceB(), op.getSourceB(), operands);
- wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
- op.getSubwordOffset(), op.getClamp(), operands);
+ SmallVector<NamedAttribute, 4> attrs;
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(), a,
+ op.getSourceA(), operands, attrs, "signA");
+ wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(), b,
+ op.getSourceB(), operands, attrs, "signB");
+ wmmaPushOutputOperand(rewriter, loc, typeConverter, destC,
+ op.getSubwordOffset(), op.getClamp(), operands,
+ attrs);
+ OperationState loweredOp(loc, *maybeIntrinsic);
+ loweredOp.addTypes(rawOutType);
loweredOp.addOpe...
[truncated]
|
…perands where possible Signed-off-by: Muzammiluddin Syed <[email protected]>
aedef1f to
774e166
Compare
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
774e166 to
16f5116
Compare
Signed-off-by: Muzammiluddin Syed <[email protected]>
16f5116 to
fef3ace
Compare
- Use SmallVectorImpl for output parameters per LLVM coding standards
- Fix type casting bug (getSourceA -> getSourceB on line 1309)
- Use cast<> instead of dyn_cast<> for guaranteed vector types
- Simplify assembly format with functional-type shorthand

Addresses review comments from @kuhar on PR llvm#167041
Update the Python test to match the new WMMA API where operands are passed as separate positional arguments and attributes (like opsel) are passed as keyword arguments with boolean values.

Changes:
- Changed from list-style arguments to positional arguments
- Changed opsel from MLIR Value operand to Python bool attribute
- Removed unnecessary false constant creation

This fixes the CI test failure where the Python bindings were expecting the old API signature.
3bb9af3 to
b768b8a
Compare
kuhar
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but wait for @krzysz00 to review as well
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's ensure we have (working, precise) tests for all the things that just got made attributes with their non-default values, otherwise LGTM here
04197af to
77eaa97
Compare
Signed-off-by: Muzammiluddin Syed <[email protected]>
77eaa97 to
21ea4d7
Compare
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM now
The current implementation of the WMMA intrinsic ops, as defined in the ROCDL tablegen, is incorrect: it represents as operands what should be attributes, such as `clamp`, `opsel`, and `signA`/`signB`. This change refactors the ops to bring them in line with what we expect.