Skip to content

Commit a7a804e

Browse files
committed
clang format
1 parent 07d2a1e commit a7a804e

File tree

10 files changed

+152
-162
lines changed

10 files changed

+152
-162
lines changed

mlir/include/mlir/Dialect/Rock/IR/GemmGemmSize.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file defines a utility struct, GemmGemmSize, that packages the sizes of
9+
// This file defines a utility struct, GemmGemmSize, that packages the sizes of
1010
// gemm+gemm to ensure a cleaner API.
1111
//
1212
//===----------------------------------------------------------------------===//
@@ -31,7 +31,8 @@ struct GemmGemmSize {
3131
: g(g), m(m), k(k), n(n), o(o) {}
3232

3333
bool operator==(const GemmGemmSize &other) {
34-
return (g == other.g) && (m == other.m) && (k == other.k) && (n == other.n) && (o == other.o);
34+
return (g == other.g) && (m == other.m) && (k == other.k) &&
35+
(n == other.n) && (o == other.o);
3536
}
3637
};
3738
} // end namespace rock

mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
//===- RockGemmGemmWrapperInterface.h - ops that wrap rock.attention -*- C++ -*-===//
1+
//===- RockGemmGemmWrapperInterface.h - ops that wrap rock.attention -*- C++
2+
//-*-===//
23
//
34
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
45
// Exceptions. See https://llvm.org/LICENSE.txt for license information.

mlir/include/mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
//===- RockGemmGemmWrapperInterface.td - ops that wrap rock.attention ---------===//
1+
//===- RockGemmGemmWrapperInterface.td - ops that wrap rock.attention
2+
//---------===//
23
//
3-
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4-
// See https://llvm.org/LICENSE.txt for license information.
4+
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
5+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
56
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67
//
78
// Copyright (c) 2025 Advanced Micro Devices Inc.
89
//===----------------------------------------------------------------------===//
910
//
1011
// This file defines RockGemmGemmWrapperInterface, which abstracts attention and
11-
// gemm+gemm and friends (conv+gemm, ...) to allow code to operate on them generically.
12+
// gemm+gemm and friends (conv+gemm, ...) to allow code to operate on them
13+
// generically.
1214
//
1315
//===----------------------------------------------------------------------===//
1416

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 35 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -205,27 +205,24 @@ def Rock_ReduceOp :
205205
::mlir::OpOperand* getOutArgument() { return &(*this)->getOpOperand(1); }
206206
}];
207207
}
208-
def Rock_AttentionOp :
209-
Rock_Op<"attention", [DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot, AttrSizedOperandSegments]>,
210-
Arguments<(ins
211-
TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
212-
TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys,
213-
TensorOrMemRefOf<[F32, F16, BF16]>:$values,
214-
Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
215-
Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
216-
TensorOrMemRefOf<[F32, F16, BF16]>:$out,
217-
UnitAttr:$qTransposed,
218-
UnitAttr:$kTransposed,
219-
UnitAttr:$vTransposed,
220-
UnitAttr:$oTransposed,
221-
StrAttr:$arch,
222-
Rock_GemmFeaturesAttr:$features,
223-
OptionalAttr<I32Attr>:$numCU,
224-
OptionalAttr<RockTuningParamAttrInterface>:$params0,
225-
OptionalAttr<RockTuningParamAttrInterface>:$params1,
226-
I32Attr:$firstGemmIdx
227-
)>,
228-
Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
208+
def Rock_AttentionOp
209+
: Rock_Op<
210+
"attention", [DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>,
211+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
212+
RockFusionRoot, AttrSizedOperandSegments]>,
213+
Arguments<(ins TensorOrMemRefOf<[F32, F16, BF16, I8]>:$queries,
214+
TensorOrMemRefOf<[F32, F16, BF16, I8]>:$keys,
215+
TensorOrMemRefOf<[F32, F16, BF16]>:$values,
216+
Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
217+
Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
218+
TensorOrMemRefOf<[F32, F16, BF16]>:$out, UnitAttr:$qTransposed,
219+
UnitAttr:$kTransposed, UnitAttr:$vTransposed, UnitAttr:$oTransposed,
220+
StrAttr:$arch, Rock_GemmFeaturesAttr:$features,
221+
OptionalAttr<I32Attr>:$numCU,
222+
OptionalAttr<RockTuningParamAttrInterface>:$params0,
223+
OptionalAttr<RockTuningParamAttrInterface>:$params1,
224+
I32Attr:$firstGemmIdx)>,
225+
Results<(outs Optional<TensorOf<[F32, F16, BF16]>>:$result)> {
229226
let summary = "Attention operation of transformer models";
230227
let description = [{
231228
Performs the operation out = SOFTMAX(queries * keys) * values.
@@ -252,27 +249,23 @@ def Rock_AttentionOp :
252249
}];
253250
}
254251

255-
def Rock_GemmElementwiseGemmOp:
256-
Rock_Op<"gemm_elementwise_gemm", [DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, RockFusionRoot]>,
257-
AllElementTypesMatch<["a", "b", "c"]>,
258-
Arguments<(ins
259-
TensorOrMemRefOf<[F32]>:$a,
260-
TensorOrMemRefOf<[F32]>:$b,
261-
TensorOrMemRefOf<[F32]>:$c,
262-
Variadic<AnyTensorOrMemRef>:$elemwiseInputs,
263-
TensorOrMemRefOf<[F32]>:$out,
264-
UnitAttr:$aTransposed,
265-
UnitAttr:$bTransposed,
266-
UnitAttr:$cTransposed,
267-
UnitAttr:$oTransposed,
268-
StrAttr:$arch,
269-
Rock_GemmFeaturesAttr:$features,
270-
OptionalAttr<I32Attr>:$numCU,
271-
OptionalAttr<RockTuningParamAttrInterface>:$params0,
272-
OptionalAttr<RockTuningParamAttrInterface>:$params1,
273-
I32Attr:$firstGemmIdx
274-
)>,
275-
Results<(outs Optional<TensorOf<[F32]>>:$result)> {
252+
def Rock_GemmElementwiseGemmOp
253+
: Rock_Op<"gemm_elementwise_gemm",
254+
[DeclareOpInterfaceMethods<RockGemmGemmWrapperInterface>,
255+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
256+
RockFusionRoot]>,
257+
AllElementTypesMatch<["a", "b", "c"]>,
258+
Arguments<(ins TensorOrMemRefOf<[F32]>:$a, TensorOrMemRefOf<[F32]>:$b,
259+
TensorOrMemRefOf<[F32]>:$c,
260+
Variadic<AnyTensorOrMemRef>:$elemwiseInputs,
261+
TensorOrMemRefOf<[F32]>:$out, UnitAttr:$aTransposed,
262+
UnitAttr:$bTransposed, UnitAttr:$cTransposed, UnitAttr:$oTransposed,
263+
StrAttr:$arch, Rock_GemmFeaturesAttr:$features,
264+
OptionalAttr<I32Attr>:$numCU,
265+
OptionalAttr<RockTuningParamAttrInterface>:$params0,
266+
OptionalAttr<RockTuningParamAttrInterface>:$params1,
267+
I32Attr:$firstGemmIdx)>,
268+
Results<(outs Optional<TensorOf<[F32]>>:$result)> {
276269
let summary = "GEMM-elementwise-GEMM operation";
277270
let description = [{
278271
Performs the operation out = (a * b) * c.

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 44 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Rock/IR/Rock.h"
10-
#include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
1110
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
11+
#include "mlir/Dialect/Rock/IR/RockGemmWrapperInterface.h"
1212
#include "mlir/Dialect/Rock/IR/RockTypes.h"
1313
#include "mlir/Dialect/Rock/utility/math.h"
1414

@@ -2081,48 +2081,38 @@ LogicalResult BlockwiseFillOp::verify() {
20812081
//===-----------------------------------------------------===//
20822082

20832083
OpOperand *GemmElementwiseGemmOp::getOutArgument() {
2084-
return &(*this)->getOpOperand(getNumOperands()-1);
2084+
return &(*this)->getOpOperand(getNumOperands() - 1);
20852085
}
20862086

2087-
Type GemmElementwiseGemmOp::getOutType() {
2088-
return getOut().getType();
2089-
}
2087+
Type GemmElementwiseGemmOp::getOutType() { return getOut().getType(); }
20902088

2091-
Type GemmElementwiseGemmOp::getAType() {
2092-
return getA().getType();
2093-
}
2089+
Type GemmElementwiseGemmOp::getAType() { return getA().getType(); }
20942090

2095-
Type GemmElementwiseGemmOp::getBType() {
2096-
return getB().getType();
2097-
}
2091+
Type GemmElementwiseGemmOp::getBType() { return getB().getType(); }
20982092

2099-
Type GemmElementwiseGemmOp::getCType() {
2100-
return getC().getType();
2101-
}
2093+
Type GemmElementwiseGemmOp::getCType() { return getC().getType(); }
21022094

2103-
bool GemmElementwiseGemmOp::getTransposedA() {
2104-
return getATransposed();
2105-
}
2095+
bool GemmElementwiseGemmOp::getTransposedA() { return getATransposed(); }
21062096

2107-
bool GemmElementwiseGemmOp::getTransposedB() {
2108-
return getBTransposed();
2109-
}
2097+
bool GemmElementwiseGemmOp::getTransposedB() { return getBTransposed(); }
21102098

2111-
bool GemmElementwiseGemmOp::getTransposedC() {
2112-
return getCTransposed();
2113-
}
2099+
bool GemmElementwiseGemmOp::getTransposedC() { return getCTransposed(); }
21142100

2115-
bool GemmElementwiseGemmOp::getTransposedOut() {
2116-
return getOTransposed();
2117-
}
2101+
bool GemmElementwiseGemmOp::getTransposedOut() { return getOTransposed(); }
21182102

2119-
KernelType GemmElementwiseGemmOp::getKernelType() { return KernelType::GemmElementwiseGemm; }
2103+
KernelType GemmElementwiseGemmOp::getKernelType() {
2104+
return KernelType::GemmElementwiseGemm;
2105+
}
21202106

2121-
uint32_t GemmElementwiseGemmOp::getFirstGemmIndex() { return getFirstGemmIdx(); }
2107+
uint32_t GemmElementwiseGemmOp::getFirstGemmIndex() {
2108+
return getFirstGemmIdx();
2109+
}
21222110

21232111
GemmGemmSize GemmElementwiseGemmOp::getGemmGemmSize() {
2124-
ShapedType typeA = getA().getType(), typeB = getB().getType(), typeC = getC().getType();
2125-
ArrayRef<int64_t> dimsA = typeA.getShape(), dimsB = typeB.getShape(), dimsC = typeC.getShape();
2112+
ShapedType typeA = getA().getType(), typeB = getB().getType(),
2113+
typeC = getC().getType();
2114+
ArrayRef<int64_t> dimsA = typeA.getShape(), dimsB = typeB.getShape(),
2115+
dimsC = typeC.getShape();
21262116
int64_t offsetA = dimsA.size() == 2 ? 0 : 1,
21272117
offsetB = dimsB.size() == 2 ? 0 : 1,
21282118
offsetC = dimsC.size() == 2 ? 0 : 1;
@@ -2134,25 +2124,28 @@ GemmGemmSize GemmElementwiseGemmOp::getGemmGemmSize() {
21342124
return GemmGemmSize(g, m, k, n, o);
21352125
}
21362126

2137-
static LogicalResult verifyAttentionOp(RockGemmGemmWrapperInterface op,
2127+
static LogicalResult verifyAttentionOp(RockGemmGemmWrapperInterface op,
21382128
Value currentSeqLen) {
21392129
ShapedType qType = cast<ShapedType>(op.getAType());
21402130
int64_t qBatchDim = qType.getShape().size() == 3 ? qType.getShape()[0] : 1;
21412131
ArrayRef<int64_t> qLastDims = qType.getShape().slice(qType.getRank() - 2);
2142-
auto [queryM, queryK] = op.getTransposedA() ? std::tuple{qLastDims[1], qLastDims[0]}
2143-
: std::tuple{qLastDims[0], qLastDims[1]};
2132+
auto [queryM, queryK] = op.getTransposedA()
2133+
? std::tuple{qLastDims[1], qLastDims[0]}
2134+
: std::tuple{qLastDims[0], qLastDims[1]};
21442135

21452136
ShapedType kType = cast<ShapedType>(op.getBType());
21462137
int64_t kBatchDim = kType.getShape().size() == 3 ? kType.getShape()[0] : 1;
21472138
ArrayRef<int64_t> kLastDims = kType.getShape().slice(kType.getRank() - 2);
2148-
auto [keyK, keyN] = op.getTransposedB() ? std::tuple{kLastDims[1], kLastDims[0]}
2149-
: std::tuple{kLastDims[0], kLastDims[1]};
2139+
auto [keyK, keyN] = op.getTransposedB()
2140+
? std::tuple{kLastDims[1], kLastDims[0]}
2141+
: std::tuple{kLastDims[0], kLastDims[1]};
21502142

21512143
ShapedType vType = cast<ShapedType>(op.getCType());
21522144
int64_t vBatchDim = vType.getShape().size() == 3 ? vType.getShape()[0] : 1;
21532145
ArrayRef<int64_t> vLastDims = vType.getShape().slice(vType.getRank() - 2);
2154-
auto [valueK, valueN] = op.getTransposedC() ? std::tuple{vLastDims[1], vLastDims[0]}
2155-
: std::tuple{vLastDims[0], vLastDims[1]};
2146+
auto [valueK, valueN] = op.getTransposedC()
2147+
? std::tuple{vLastDims[1], vLastDims[0]}
2148+
: std::tuple{vLastDims[0], vLastDims[1]};
21562149

21572150
if (qBatchDim != kBatchDim || kBatchDim != vBatchDim) {
21582151
return op.emitError("Batch dimensions do not match");
@@ -2171,7 +2164,7 @@ static LogicalResult verifyAttentionOp(RockGemmGemmWrapperInterface op,
21712164
ArrayRef<int64_t> oLastDims = oType.getShape().slice(oType.getRank() - 2);
21722165
auto [outputSeqLen, outputHeadDim] =
21732166
op.getTransposedOut() ? std::tuple{oLastDims[1], oLastDims[0]}
2174-
: std::tuple{oLastDims[0], oLastDims[1]};
2167+
: std::tuple{oLastDims[0], oLastDims[1]};
21752168

21762169
if (qType.getShape().size() != oType.getShape().size()) {
21772170
return op.emitError("Number of dimensions do not match (Q and Output)");
@@ -2223,48 +2216,34 @@ void GemmElementwiseGemmOp::getEffects(
22232216
//===-----------------------------------------------------===//
22242217

22252218
OpOperand *AttentionOp::getOutArgument() {
2226-
return &(*this)->getOpOperand(getNumOperands()-1);
2219+
return &(*this)->getOpOperand(getNumOperands() - 1);
22272220
}
22282221

2229-
Type AttentionOp::getOutType() {
2230-
return getOut().getType();
2231-
}
2222+
Type AttentionOp::getOutType() { return getOut().getType(); }
22322223

2233-
Type AttentionOp::getAType() {
2234-
return getQueries().getType();
2235-
}
2224+
Type AttentionOp::getAType() { return getQueries().getType(); }
22362225

2237-
Type AttentionOp::getBType() {
2238-
return getKeys().getType();
2239-
}
2226+
Type AttentionOp::getBType() { return getKeys().getType(); }
22402227

2241-
Type AttentionOp::getCType() {
2242-
return getValues().getType();
2243-
}
2228+
Type AttentionOp::getCType() { return getValues().getType(); }
22442229

2245-
bool AttentionOp::getTransposedA() {
2246-
return getQTransposed();
2247-
}
2230+
bool AttentionOp::getTransposedA() { return getQTransposed(); }
22482231

2249-
bool AttentionOp::getTransposedB() {
2250-
return getKTransposed();
2251-
}
2232+
bool AttentionOp::getTransposedB() { return getKTransposed(); }
22522233

2253-
bool AttentionOp::getTransposedC() {
2254-
return getVTransposed();
2255-
}
2234+
bool AttentionOp::getTransposedC() { return getVTransposed(); }
22562235

2257-
bool AttentionOp::getTransposedOut() {
2258-
return getOTransposed();
2259-
}
2236+
bool AttentionOp::getTransposedOut() { return getOTransposed(); }
22602237

22612238
KernelType AttentionOp::getKernelType() { return KernelType::Attention; }
22622239

22632240
uint32_t AttentionOp::getFirstGemmIndex() { return getFirstGemmIdx(); }
22642241

22652242
GemmGemmSize AttentionOp::getGemmGemmSize() {
2266-
ShapedType typeA = getQueries().getType(), typeB = getKeys().getType(), typeC = getValues().getType();
2267-
ArrayRef<int64_t> dimsA = typeA.getShape(), dimsB = typeB.getShape(), dimsC = typeC.getShape();
2243+
ShapedType typeA = getQueries().getType(), typeB = getKeys().getType(),
2244+
typeC = getValues().getType();
2245+
ArrayRef<int64_t> dimsA = typeA.getShape(), dimsB = typeB.getShape(),
2246+
dimsC = typeC.getShape();
22682247
int64_t offsetA = dimsA.size() == 2 ? 0 : 1,
22692248
offsetB = dimsB.size() == 2 ? 0 : 1,
22702249
offsetC = dimsC.size() == 2 ? 0 : 1;

mlir/lib/Dialect/Rock/IR/RockGemmGemmWrapperInterface.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
//===- RockGemmGemmWrapperInterface.cpp - ops that wrap rock.attention -------===//
1+
//===- RockGemmGemmWrapperInterface.cpp - ops that wrap rock.attention
2+
//-------===//
23
//
34
// Part of the rocMLIR Project, under the Apache License v2.0 with LLVM
45
// Exceptions. See https://llvm.org/LICENSE.txt for license information.

mlir/lib/Dialect/Rock/Transforms/AffixTuningParameters.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ void AffixTuningParameters::runOnOperation() {
5454

5555
func.walk(
5656
[&](RockGemmWrapperInterface op) { affixTuningParametersImpl(op); });
57-
func.walk([&](RockGemmGemmWrapperInterface op) { affixTuningParametersImpl(op); });
57+
func.walk(
58+
[&](RockGemmGemmWrapperInterface op) { affixTuningParametersImpl(op); });
5859
func.walk([&](ReduceOp op) {
5960
func::FuncOp funcOp = getOperation();
6061
if (!funcOp->hasAttr("block_size")) {
@@ -243,13 +244,14 @@ deriveGemm1TuningParams(OpBuilder &builder, RockGemmGemmWrapperInterface op,
243244
gemm0TuningParams.getOutputSwizzle(), gemm0TuningParams.getForceUnroll());
244245
}
245246

246-
void AffixTuningParameters::affixTuningParametersImpl(RockGemmGemmWrapperInterface op) {
247+
void AffixTuningParameters::affixTuningParametersImpl(
248+
RockGemmGemmWrapperInterface op) {
247249
OpBuilder builder(op.getContext());
248250
bool isAccel = rock::isAccel(op.getGemmFeatures());
249251
if (!isAccel) {
250252
op.emitError("Currently, attention/gemm+gemm op is only "
251-
"supported on GPUs "
252-
"with matrix accelerator extentions");
253+
"supported on GPUs "
254+
"with matrix accelerator extentions");
253255
return signalPassFailure();
254256
}
255257
Attribute params0 = op.getGemm0Params().value_or(nullptr);
@@ -303,12 +305,14 @@ void AffixTuningParameters::affixTuningParametersImpl(RockGemmGemmWrapperInterfa
303305
LLVM_DEBUG(llvm::dbgs() << "accelParams1=" << accelParams1 << "\n");
304306
LogicalResult isValidBlockwiseGemm0 =
305307
populateParamsAccelPtr->isValidBlockwiseGemm(
306-
accelParams0, cast<MemRefType>(op.getAType()).getElementType(), cast<MemRefType>(op.getBType()).getElementType(), op.getArch(),
308+
accelParams0, cast<MemRefType>(op.getAType()).getElementType(),
309+
cast<MemRefType>(op.getBType()).getElementType(), op.getArch(),
307310
/*enableBlockSizeUpperLimit=*/false,
308311
/*enableDPerWaveFiltering=*/false);
309312
LogicalResult isValidBlockwiseGemm1 =
310313
populateParamsAccelPtr->isValidBlockwiseGemm(
311-
accelParams1, cast<MemRefType>(op.getCType()).getElementType(), cast<MemRefType>(op.getCType()).getElementType(), op.getArch(),
314+
accelParams1, cast<MemRefType>(op.getCType()).getElementType(),
315+
cast<MemRefType>(op.getCType()).getElementType(), op.getArch(),
312316
/*enableBlockSizeUpperLimit=*/false,
313317
/*enableDPerWaveFiltering=*/false);
314318
if (isValidBlockwiseGemm0.failed() || isValidBlockwiseGemm1.failed()) {

0 commit comments

Comments
 (0)