Skip to content

Commit 8d27a5c

Browse files
committed
Add Gemm+Elementwise+Gemm support
1 parent a64d2c2 commit 8d27a5c

34 files changed

+1753
-566
lines changed

mlir/include/mlir/Dialect/Rock/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mlir_tablegen(RockAccelTuningParamAttrInterface.h.inc -gen-attr-interface-decls)
2020
mlir_tablegen(RockAccelTuningParamAttrInterface.cpp.inc -gen-attr-interface-defs)
2121
add_public_tablegen_target(MLIRRockAccelTuningParamAttrInterfaceIncGen)
2222

23+
add_mlir_interface(RockGemmGemmWrapperInterface)
2324
add_mlir_interface(RockGemmWrapperInterface)
2425
add_mlir_interface(RockConvInterface)
2526
add_mlir_interface(RockAcceptingViewOpInterface)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===--------- GemmGemmSize.h - utility struct for gemm+gemm ----------===//
2+
//
3+
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines a utility struct, GemmGemmSize, that packages the sizes of
10+
// gemm+gemm to ensure a cleaner API.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_DIALECT_ROCK_IR_GEMMGEMMCONTEXT_H
15+
#define MLIR_DIALECT_ROCK_IR_GEMMGEMMCONTEXT_H
16+
17+
#include <cstdint>
18+
19+
namespace mlir {
20+
namespace rock {
21+
22+
/// Structure for holding the problem sizes of a fused gemm+gemm operation
/// (two chained matrix multiplications).
///
/// The struct is a plain value type: all five extents are public and are
/// stored exactly as passed to the constructor, with no validation.
/// NOTE(review): presumably `g` is the group/batch count, `m`, `k`, `n`
/// describe the first gemm and `o` is the extra output extent of the second
/// gemm -- confirm against the ops implementing getGemmGemmSize().
struct GemmGemmSize {
  int64_t g;
  int64_t m;
  int64_t k;
  int64_t n;
  int64_t o;

  /// Store the five extents verbatim.
  GemmGemmSize(int64_t g, int64_t m, int64_t k, int64_t n, int64_t o)
      : g(g), m(m), k(k), n(n), o(o) {}

  /// Two sizes are equal iff all five extents match.
  /// Declared const so that const objects and const references can be
  /// compared.
  bool operator==(const GemmGemmSize &other) const {
    return (g == other.g) && (m == other.m) && (k == other.k) &&
           (n == other.n) && (o == other.o);
  }

  /// Negation of operator==, provided explicitly because C++17 does not
  /// synthesize it.
  bool operator!=(const GemmGemmSize &other) const { return !(*this == other); }
};
38+
} // end namespace rock
39+
} // end namespace mlir
40+
#endif // MLIR_DIALECT_ROCK_IR_GEMMGEMMCONTEXT_H

mlir/include/mlir/Dialect/Rock/IR/Rock.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class PatternRewriter;
3838
#include "mlir/Dialect/Rock/IR/RockTypes.h"
3939

4040
#include "mlir/Dialect/Rock/IR/ConvolutionDims.h"
41+
#include "mlir/Dialect/Rock/IR/GemmGemmSize.h"
4142
#include "mlir/Dialect/Rock/IR/GemmSize.h"
4243

4344
namespace mlir {
@@ -49,12 +50,6 @@ class FusionRoot : public TraitBase<ConcreteType, FusionRoot> {};
4950
} // namespace OpTrait
5051
} // namespace mlir
5152

52-
// Following ifdef could be used to change
53-
// the attention operator to be a fused gemm-gemm
54-
// kernel for debugging purposes. This will also
55-
// adjust the test harness to verify the same as well
56-
// #define ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX
57-
5853
namespace mlir {
5954
namespace rock {
6055
//===----------------------------------------------------------------------===//
@@ -87,6 +82,7 @@ constexpr int64_t maxHardwareWorkgroupSize = 1024;
8782

8883
#include "mlir/Dialect/Rock/IR/RockAcceptingViewOpInterface.h"
8984
#include "mlir/Dialect/Rock/IR/RockConvInterface.h"
85+
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
9086
#include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
9187
#include "mlir/Dialect/Rock/IR/RockWriterOpInterface.h"
9288

mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define ROCK_ATTRS
1111

1212
include "mlir/Dialect/Rock/IR/RockBase.td"
13+
include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td"
1314
include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.td"
1415
include "mlir/Dialect/Rock/IR/RockTuningParamAttrInterface.td"
1516
include "mlir/Dialect/Rock/IR/RockAccelTuningParamAttrInterface.td"
@@ -63,11 +64,13 @@ def KernelTypeConvBwdData : I32EnumAttrCase<"ConvBwdData", 1>;
6364
def KernelTypeConvBwdWeight : I32EnumAttrCase<"ConvBwdWeight", 2>;
6465
def KernelTypeGemm : I32EnumAttrCase<"Gemm", 3>;
6566
def KernelTypeAttention : I32EnumAttrCase<"Attention", 4>;
67+
def KernelTypeGemmElementwiseGemm : I32EnumAttrCase<"GemmElementwiseGemm", 5>;
6668

67-
def KernelType : Rock_I32Enum<"KernelType", "Any of the possible types of a rock kernel",
68-
[KernelTypeConv, KernelTypeConvBwdData,
69-
KernelTypeConvBwdWeight, KernelTypeGemm,
70-
KernelTypeAttention]>;
69+
def KernelType
70+
: Rock_I32Enum<"KernelType", "Any of the possible types of a rock kernel",
71+
[KernelTypeConv, KernelTypeConvBwdData,
72+
KernelTypeConvBwdWeight, KernelTypeGemm,
73+
KernelTypeAttention, KernelTypeGemmElementwiseGemm]>;
7174

7275
/// TransformType
7376
def PassThrough : I32EnumAttrCase<"PassThrough", 0>;
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- RockGemmGemmWrapperInterface.h - gemm+gemm wrapper ops ---*- C++ -*-===//
3+
//
4+
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
5+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
// Copyright (c) 2025 Advanced Micro Devices Inc.
9+
//===----------------------------------------------------------------------===//
10+
//
11+
// This file defines RockGemmGemmWrapperInterface, which abstracts attention and
12+
// gemm+gemm to allow code to operate on them generically.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#ifndef MLIR_DIALECT_ROCK_IR_ROCKGEMMGEMMWRAPPERINTERFACE_H
17+
#define MLIR_DIALECT_ROCK_IR_ROCKGEMMGEMMWRAPPERINTERFACE_H
18+
19+
#include "mlir/Dialect/Rock/IR/GemmGemmSize.h"
20+
#include "mlir/IR/OpDefinition.h"
21+
22+
#include "mlir/Dialect/Rock/IR/RockTypes.h"
23+
24+
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h.inc"
25+
26+
#endif // MLIR_DIALECT_ROCK_IR_ROCKGEMMGEMMWRAPPERINTERFACE_H
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
//===- RockGemmGemmWrapperInterface.td - gemm+gemm wrapper ops ------------===//
3+
//
4+
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
5+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
// Copyright (c) 2025 Advanced Micro Devices Inc.
9+
//===----------------------------------------------------------------------===//
10+
//
11+
// This file defines RockGemmGemmWrapperInterface, which abstracts attention and
12+
// gemm+gemm and friends (conv+gemm, ...) to allow code to operate on them
13+
// generically.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#ifndef ROCK_GEMM_GEMM_WRAPPER_INTERFACE
18+
#define ROCK_GEMM_GEMM_WRAPPER_INTERFACE
19+
20+
include "mlir/IR/OpBase.td"
21+
22+
def RockGemmGemmWrapperInterface : OpInterface<"RockGemmGemmWrapperInterface"> {
23+
let description = [{
24+
Interface to abstract away gemm+gemm-wrapping operators in the rock dialect,
25+
which mainly include attention and gemm+gemm and friends that can be implemented
26+
with flash attention.
27+
28+
This should include functions to get common attributes.
29+
}];
30+
let cppNamespace = "::mlir::rock";
31+
32+
let methods = [
33+
InterfaceMethod<
34+
/*desc=*/[{
35+
Return the KernelType of this op
36+
}],
37+
/*retType=*/"::mlir::rock::KernelType",
38+
/*methodName=*/"getKernelType",
39+
/*args=*/(ins),
40+
/*methodBody=*/"",
41+
/*defaultImplementation=*/""
42+
>,
43+
InterfaceMethod<
44+
/*desc=*/[{
45+
Return the arch string of this op
46+
}],
47+
/*retType=*/"StringRef",
48+
/*methodName=*/"getArch",
49+
/*args=*/(ins),
50+
/*methodBody=*/"",
51+
/*defaultImplementation=*/""
52+
>,
53+
InterfaceMethod<
54+
/*desc=*/[{
55+
Return the OpOperand for the argument
56+
that corresponds to the output result of the operation.
57+
}],
58+
/*retType=*/"OpOperand *",
59+
/*methodName=*/"getOutArgument",
60+
/*args=*/(ins),
61+
/*methodBody=*/"",
62+
/*defaultImplementation=*/""
63+
>,
64+
InterfaceMethod<
65+
/*desc=*/[{
66+
Return the sizes of the gemm+gemm problem that this op will eventually
67+
perform.
68+
}],
69+
/*retType=*/"::mlir::rock::GemmGemmSize",
70+
/*methodName=*/"getGemmGemmSize",
71+
/*args=*/(ins),
72+
/*methodBody=*/"",
73+
/*defaultImplementation=*/""
74+
>,
75+
InterfaceMethod<
76+
/*desc=*/[{
77+
Return the element type of [what will become] matrix A for this operation.
78+
}],
79+
/*retType=*/"::mlir::Type",
80+
/*methodName=*/"getAType",
81+
/*args=*/(ins),
82+
/*methodBody=*/"",
83+
/*defaultImplementation=*/""
84+
>,
85+
InterfaceMethod<
86+
/*desc=*/[{
87+
Return the element type of [what will become] matrix B for this operation.
88+
}],
89+
/*retType=*/"::mlir::Type",
90+
/*methodName=*/"getBType",
91+
/*args=*/(ins),
92+
/*methodBody=*/"",
93+
/*defaultImplementation=*/""
94+
>,
95+
InterfaceMethod<
96+
/*desc=*/[{
97+
Return the element type of [what will become] matrix C for this operation.
98+
}],
99+
/*retType=*/"::mlir::Type",
100+
/*methodName=*/"getCType",
101+
/*args=*/(ins),
102+
/*methodBody=*/"",
103+
/*defaultImplementation=*/""
104+
>,
105+
InterfaceMethod<
106+
/*desc=*/[{
107+
Return the element type of [what will become] output matrix for this operation.
108+
}],
109+
/*retType=*/"::mlir::Type",
110+
/*methodName=*/"getOutType",
111+
/*args=*/(ins),
112+
/*methodBody=*/"",
113+
/*defaultImplementation=*/""
114+
>,
115+
InterfaceMethod<
116+
/*desc=*/[{
117+
Return whether matrix A is transposed.
118+
}],
119+
/*retType=*/"bool",
120+
/*methodName=*/"getTransposedA",
121+
/*args=*/(ins),
122+
/*methodBody=*/"",
123+
/*defaultImplementation=*/""
124+
>,
125+
InterfaceMethod<
126+
/*desc=*/[{
127+
Return whether matrix B is transposed.
128+
}],
129+
/*retType=*/"bool",
130+
/*methodName=*/"getTransposedB",
131+
/*args=*/(ins),
132+
/*methodBody=*/"",
133+
/*defaultImplementation=*/""
134+
>,
135+
InterfaceMethod<
136+
/*desc=*/[{
137+
Return whether matrix C is transposed.
138+
}],
139+
/*retType=*/"bool",
140+
/*methodName=*/"getTransposedC",
141+
/*args=*/(ins),
142+
/*methodBody=*/"",
143+
/*defaultImplementation=*/""
144+
>,
145+
InterfaceMethod<
146+
/*desc=*/[{
147+
Return whether the output matrix is transposed.
148+
}],
149+
/*retType=*/"bool",
150+
/*methodName=*/"getTransposedOut",
151+
/*args=*/(ins),
152+
/*methodBody=*/"",
153+
/*defaultImplementation=*/""
154+
>,
155+
InterfaceMethod<
156+
/*desc=*/[{
157+
Return the features attribute of this op.
158+
}],
159+
/*retType=*/"::mlir::rock::GemmFeatures",
160+
/*methodName=*/"getGemmFeatures",
161+
/*args=*/(ins),
162+
/*methodBody=*/"",
163+
/*defaultImplementation=*/[{
164+
return $_op.getFeatures();
165+
}]
166+
>,
167+
InterfaceMethod<
168+
/*desc=*/[{
169+
Return the optional number of Compute Units the GPU provides.
170+
}],
171+
/*retType=*/"std::optional<uint32_t>",
172+
/*methodName=*/"getNumCU",
173+
/*args=*/(ins),
174+
/*methodBody=*/"",
175+
/*defaultImplementation=*/ ""
176+
>,
177+
178+
InterfaceMethod<
179+
/*desc=*/[{
180+
Set the tuning parameters attribute of the first GEMM
181+
}],
182+
/*retType=*/"void",
183+
/*methodName=*/"setGemm0ParamsAttr",
184+
/*args=*/(ins "::mlir::Attribute":$params),
185+
/*methodBody=*/"",
186+
/*defaultImplementation=*/[{
187+
$_op->setAttr($_op.getParams0AttrName(), params);
188+
}]
189+
>,
190+
InterfaceMethod<
191+
/*desc=*/[{
192+
Set the tuning parameters attribute of the second GEMM
193+
}],
194+
/*retType=*/"void",
195+
/*methodName=*/"setGemm1ParamsAttr",
196+
/*args=*/(ins "::mlir::Attribute":$params),
197+
/*methodBody=*/"",
198+
/*defaultImplementation=*/[{
199+
$_op->setAttr($_op.getParams1AttrName(), params);
200+
}]
201+
>,
202+
InterfaceMethod<
203+
/*desc=*/[{
204+
Get the tuning parameters attribute of the first GEMM
205+
}],
206+
/*retType=*/"std::optional<RockTuningParamAttrInterface>",
207+
/*methodName=*/"getGemm0Params",
208+
/*args=*/(ins),
209+
/*methodBody=*/"",
210+
/*defaultImplementation=*/[{
211+
return $_op.getParams0();
212+
}]
213+
>,
214+
InterfaceMethod<
215+
/*desc=*/[{
216+
Get the tuning parameters attribute of the second GEMM
217+
}],
218+
/*retType=*/"std::optional<RockTuningParamAttrInterface>",
219+
/*methodName=*/"getGemm1Params",
220+
/*args=*/(ins),
221+
/*methodBody=*/"",
222+
/*defaultImplementation=*/[{
223+
return $_op.getParams1();
224+
}]
225+
>,
226+
InterfaceMethod<
227+
/*desc=*/[{
228+
Return the index of the elementwise region argument that comes from the first GEMM.
229+
}],
230+
/*retType=*/"uint32_t",
231+
/*methodName=*/"getFirstGemmIndex",
232+
/*args=*/(ins),
233+
/*methodBody=*/"",
234+
/*defaultImplementation=*/ ""
235+
>,
236+
237+
// TODO: more methods here as needed
238+
];
239+
240+
let verify = [{
241+
auto concreteOp = ::mlir::cast<ConcreteOp>($_op);
242+
if ($_op->getNumResults() == 1) {
243+
if ($_op->getResult(0).getType() !=
244+
concreteOp.getOutArgument()->get().getType()) {
245+
return $_op->emitOpError("result type must match output argument type");
246+
}
247+
}
248+
return ::mlir::success();
249+
}];
250+
}
251+
252+
#endif // ROCK_GEMM_GEMM_WRAPPER_INTERFACE

0 commit comments

Comments
 (0)