
Commit ec1cdfe

Attention: return LSE (log-sum-exp) (#1882)
This PR introduces the ability for attention to return LSE (log-sum-exp). Note that migraphx requires L to be base-2 (log2).
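
For reference, here is the standard row-wise definition (not spelled out in the commit message): for one query row with pre-softmax scores s_1, ..., s_n,

$$
\mathrm{LSE} \;=\; \log \sum_{j=1}^{n} e^{s_j} \;=\; m + \log \sum_{j=1}^{n} e^{s_j - m}, \qquad m = \max_j s_j,
$$

and the base-2 value that migraphx expects is simply $\mathrm{LSE}_2 = \mathrm{LSE} / \ln 2 = \log_2 \sum_j e^{s_j}$.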
1 parent 38fc327 commit ec1cdfe

30 files changed: +497 −106 lines
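
As a minimal, standalone reference for what the returned LSE contains, here is a plain C++ sketch of the per-row computation in base 2 (an illustration under the conventions above, not the rocMLIR kernel; rowLse2 is a hypothetical helper name):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable log-sum-exp of one row of pre-softmax scores,
// returned in base 2 to match the convention noted in the commit message.
// Hypothetical helper for illustration only.
double rowLse2(const std::vector<double> &scores) {
  double m = *std::max_element(scores.begin(), scores.end());
  double sum = 0.0;
  for (double s : scores)
    sum += std::exp(s - m);
  // log2(sum_j exp(s_j)) = (m + log(sum_j exp(s_j - m))) / ln 2
  return (m + std::log(sum)) / std::log(2.0);
}

int main() {
  std::vector<double> row = {0.5, 1.25, -0.75}; // one hypothetical score row
  double lse2 = rowLse2(row);
  // exp2(lse2) equals sum_j exp(s_j), so softmax weights are exp(s_j) / exp2(lse2).
  std::printf("LSE (base 2) = %f\n", lse2);
  return 0;
}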

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 14 additions & 9 deletions
@@ -165,7 +165,7 @@ def Rock_GemmOp :
     transposed. For example, if `aTransposed` is set, then the argument A should be
     a [G] x K x M memory.

-    Those creating a `rock.gemm` must specify the GPU architecture being targetted
+    Those creating a `rock.gemm` must specify the GPU architecture being targeted
     and the number of compute units (numCu) available. The parameters
     `derivedBlockSize`, `gridSize`, and `params` are optional as they can be inferred by
     a tuning process or a heuristic, but they must be set before the `gemm` is
@@ -215,10 +215,11 @@ def Rock_AttentionOp
     TensorOrMemRefOf<[F32, F16, BF16]>:$values,
     Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
     Optional<TensorOrMemRefOf<[I32]>>:$currentSeqLen,
-    TensorOrMemRefOf<[F32, F16, BF16]>:$out, UnitAttr:$qTransposed,
-    UnitAttr:$kTransposed, UnitAttr:$vTransposed, UnitAttr:$oTransposed,
-    UnitAttr:$causal, StrAttr:$arch, Rock_GemmFeaturesAttr:$features,
-    OptionalAttr<I32Attr>:$numCU,
+    TensorOrMemRefOf<[F32, F16, BF16]>:$out,
+    Optional<TensorOrMemRefOf<[F32, F16, BF16]>>:$lse,
+    UnitAttr:$qTransposed, UnitAttr:$kTransposed, UnitAttr:$vTransposed,
+    UnitAttr:$oTransposed, UnitAttr:$causal, StrAttr:$arch,
+    Rock_GemmFeaturesAttr:$features, OptionalAttr<I32Attr>:$numCU,
     OptionalAttr<RockTuningParamAttrInterface>:$params0,
     OptionalAttr<RockTuningParamAttrInterface>:$params1,
     I32Attr:$firstGemmIdx)>,
@@ -240,7 +241,9 @@ def Rock_AttentionOp

     If causal is enabled, we implement causal masking.

-    Those creating a `rock.attention` must specify the GPU architecture being targetted
+    LSE (log-sum-exp) is an optional output typically used for flash decoding.
+
+    Those creating a `rock.attention` must specify the GPU architecture being targeted
     and the number of compute units (numCu) available. The parameters
     `gridSize`, and `blockSize` are optional as they can be inferred by
     a tuning process or a heuristic, but they must be set before the `attention` is
@@ -255,6 +258,7 @@ def Rock_AttentionOp
     ` ` `qk` `=` (`tr` $qTransposed^)? $queries `*` (`tr` $kTransposed^)? $keys `:` type($queries) `,` type($keys) `\n`
     (`currentSeqLen` `=` `(` $currentSeqLen^ `:` type($currentSeqLen) `)` `\n`)?
     (`causal` `\n` $causal^)?
+    (`lse` `=` $lse^ `:` type($lse) `\n`)?
     (`qk` `=` `elementwise` (`otherIns` `(` $preSoftmaxElemWiseInputs^ `:` type($preSoftmaxElemWiseInputs) `)`)? $preSoftmaxBody^ `\n`)?
     (`tr` $oTransposed^)? $out `=` `softmax` `(` `qk` `)` `*` (`tr` $vTransposed^)? $values `:` type($values) `->` type($out) `\n`
     `}` attr-dict (`->` type($result)^)?
@@ -294,7 +298,7 @@ def Rock_GemmElementwiseGemmOp
     transposed. For example, if `aTransposed` is set, then the argument `a` should be
     a [G] x K x M memory.

-    Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targetted
+    Those creating a `rock.gemm_elementwise_gemm` must specify the GPU architecture being targeted
     and the number of compute units (numCu) available. The parameters
     `gridSize`, and `blockSize` are optional as they can be inferred by
     a tuning process or a heuristic, but they must be set before the `gemm_elementwise_gemm` is
@@ -346,7 +350,7 @@ def Rock_ConvElementwiseGemmOp
     transposed. For example, if `cTransposed` is set, then the argument `c` should be
     a [G] x O x M memory.

-    Those creating a `rock.conv_elementwise_gemm` must specify the GPU architecture being targetted
+    Those creating a `rock.conv_elementwise_gemm` must specify the GPU architecture being targeted
     and the number of compute units (numCu) available. The parameters
     `gridSize`, and `blockSize` are optional as they can be inferred by
     a tuning process or a heuristic, but they must be set before the `conv_elementwise_gemm` is
@@ -551,7 +555,8 @@ def Rock_GridwiseAttentionAccelOp
     MemRefRankOf<[F32, F16, BF16], [3]>:$values,
     Variadic<AnyTensorOrMemRef>:$preSoftmaxElemWiseInputs,
     Optional<MemRefRankOf<[I32], [1]>>:$currentSeqLen,
-    MemRefRankOf<[F32, F16, BF16], [3]>:$out, UnitAttr:$causal,
+    MemRefRankOf<[F32, F16, BF16], [3]>:$out,
+    Optional<MemRefRankOf<[F32, F16, BF16], [2]>>:$lse, UnitAttr:$causal,
     StrAttr:$arch, Rock_GemmFeaturesAttr:$features, I32Attr:$blockSize,
     I32Attr:$gridSize, UnitAttr:$disableQBypassLDS,
     OptionalAttr<IndexAttr>:$prePadG0M,
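
A brief note on the "used for flash decoding" remark added above (standard flash-decoding algebra, not part of this diff): if attention for a query row is split across P disjoint key/value partitions, each producing a partial output O_p and a per-partition log-sum-exp L_p, the exact result is recovered by

$$
L = \log \sum_{p=1}^{P} e^{L_p}, \qquad O = \sum_{p=1}^{P} e^{L_p - L}\, O_p,
$$

with log and exp replaced by their base-2 counterparts when L_p is stored as log2. This recombination is why the kernel must expose LSE in addition to the normalized output.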

mlir/include/mlir/Dialect/Rock/utility/loweringUtils.h

Lines changed: 4 additions & 0 deletions
@@ -135,6 +135,10 @@ Type vectorTypeOrSelf(Type elementType, int64_t len);
 Value padMatrix(Value matrix, OpBuilder &b, Location loc, StringRef firstDim,
                 int64_t firstDimPad, StringRef secondDim, int64_t secondDimPad);

+// Apply padding to a vector in its `firstDim` if applicable.
+Value padVector(Value vector, OpBuilder &b, Location loc, StringRef firstDim,
+                int64_t firstDimPad);
+
 /// Normalize the argument into the form requested.
 /// If a group dimension is not present, add one.
 /// If doTranspose is true, meaning the user's transpose requests don't match

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 1 addition & 0 deletions
@@ -1693,6 +1693,7 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
     rock::AttentionOp attnOp = rewriter.create<rock::AttentionOp>(
         loc, outputType, firstMatMulOp.getA(), firstMatMulOp.getB(), op.getB(),
         elementwiseOtherArgs, currentSeqLen, output,
+        /*lse=*/nullptr,
         /*qTransposed=*/nullptr,
         /*kTransposed=*/nullptr,
         /*vTransposed=*/nullptr,

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 25 additions & 5 deletions
@@ -1869,7 +1869,11 @@ LogicalResult GridwiseAttentionAccelOp::verify() {
   int64_t gemm0kpack = gemm0TuningParams.getKpack();
   int64_t gemm0NPerBlock = gemm0TuningParams.getNPerBlock();
   if (gemm0NPerBlock % gemm0kpack != 0) {
-    return emitError("NPerBlock should be divisble by kpack.");
+    return emitError("NPerBlock should be divisible by kpack.");
+  }
+
+  if (!getEnableSoftmax() && getLse()) {
+    return emitError("LSE only works for attention.");
   }

   int64_t linalgOpCount = 0;
@@ -2126,7 +2130,7 @@ GemmGemmSize GemmElementwiseGemmOp::getGemmGemmSize() {
 }

 static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,
-                                              Value currentSeqLen) {
+                                              Value currentSeqLen, Value lse) {
   ShapedType qType = cast<ShapedType>(op.getAType());
   int64_t qBatchDim = qType.getShape().size() == 3 ? qType.getShape()[0] : 1;
   ArrayRef<int64_t> qLastDims = qType.getShape().slice(qType.getRank() - 2);
@@ -2191,11 +2195,26 @@ static LogicalResult verifyGemmPlusGemmLikeOp(RockGemmGemmWrapperInterface op,
           "Batch dimensions do not match (currentSeqLen and Output)");
     }
   }
+
+  // check LSE (log-sum-exp)
+  if (lse) {
+    ShapedType lseType = cast<ShapedType>(lse.getType());
+    if (lseType.getShape().size() != 2) {
+      return op.emitError("Number of dimensions is not two (LSE)");
+    }
+    if (lseType.getShape()[0] != oBatchDim) {
+      return op.emitError("Batch dimensions do not match (LSE and Output)");
+    }
+    if (lseType.getShape()[1] != queryM) {
+      return op.emitError("SeqLenQ dimensions do not match (LSE and Q)");
+    }
+  }
   return success();
 }

 LogicalResult GemmElementwiseGemmOp::verify() {
-  return verifyGemmPlusGemmLikeOp(*this, /*currentSeqLen=*/nullptr);
+  return verifyGemmPlusGemmLikeOp(*this, /*currentSeqLen=*/nullptr,
+                                  /*lse=*/nullptr);
 }

 void GemmElementwiseGemmOp::getEffects(
@@ -2290,7 +2309,8 @@ GemmGemmSize ConvElementwiseGemmOp::getGemmGemmSize() {
 }

 LogicalResult ConvElementwiseGemmOp::verify() {
-  return verifyGemmPlusGemmLikeOp(*this, /*currentSeqLen=*/nullptr);
+  return verifyGemmPlusGemmLikeOp(*this, /*currentSeqLen=*/nullptr,
+                                  /*lse=*/nullptr);
 }

 void ConvElementwiseGemmOp::getEffects(
@@ -2354,7 +2374,7 @@ GemmGemmSize AttentionOp::getGemmGemmSize() {
 }

 LogicalResult AttentionOp::verify() {
-  return verifyGemmPlusGemmLikeOp(*this, getCurrentSeqLen());
+  return verifyGemmPlusGemmLikeOp(*this, getCurrentSeqLen(), getLse());
 }

 void AttentionOp::getEffects(
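
Reading the new checks together (an inference from the verifier above, not an explicit statement in the commit): with batch/group size G and query sequence length seqLenQ, the optional LSE operand is expected to be two-dimensional,

$$
\mathrm{lse} \in \mathbb{R}^{G \times \mathit{seqLen}_Q},
$$

i.e. one log-sum-exp scalar per query row, matching the batch dimension of the output and the sequence dimension of Q.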

mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp

Lines changed: 12 additions & 27 deletions
@@ -92,19 +92,12 @@ struct GemmElementwiseGemmRewritePattern
   LogicalResult matchAndRewrite(GemmElementwiseGemmOp op,
                                 GemmElementwiseGemmOpAdaptor adaptor,
                                 ConversionPatternRewriter &rw) const override;
-
-  LogicalResult computeGridSize(ConversionPatternRewriter &rw,
-                                GemmElementwiseGemmOp op, Value a, Value b,
-                                Value c) const;
 };

 struct AttentionRewritePattern : public OpConversionPattern<AttentionOp> {
   using OpConversionPattern<AttentionOp>::OpConversionPattern;
   LogicalResult matchAndRewrite(AttentionOp op, AttentionOpAdaptor adaptor,
                                 ConversionPatternRewriter &rw) const override;
-
-  LogicalResult computeGridSize(ConversionPatternRewriter &rw, AttentionOp op,
-                                Value queries, Value keys, Value values) const;
 };

 template <typename Op>
@@ -139,7 +132,7 @@ computeGridSizeAttentionGemmElmtGemm(ConversionPatternRewriter &rw, Op op,
 static LogicalResult
 commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
                             RockGemmGemmWrapperInterface op, Value a, Value b,
-                            Value c, Value out, Value currentSeqLen,
+                            Value c, Value out, Value lse, Value currentSeqLen,
                             UnitAttr causal, ValueRange elementwiseInputs,
                             Region &preSecondOpRegion, bool enableSoftmax) {
   Location loc = op->getLoc();
@@ -150,17 +143,17 @@ commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
   bool isAccel = rock::isAccel(op.getGemmFeatures());
   if (!isAccel) {
     return op.emitError("Currently, op is only supported on GPUs "
-                        "with matrix accelerator extentions");
+                        "with matrix accelerator extensions");
   }
   if (!op.getGemm0Params().has_value()) {
     return op.emitError("gemm0 params is missing and it should've been "
-                        "assigned by affix-tuing-params");
+                        "assigned by affix-tuning-params");
   }
   RockAccelTuningParamAttrInterface params0 =
       cast<RockAccelTuningParamAttrInterface>(op.getGemm0Params().value());
   if (!op.getGemm1Params().has_value()) {
     return op.emitError("gemm1 params is missing and it should've been "
-                        "assigned by affix-tuing-params");
+                        "assigned by affix-tuning-params");
   }
   RockAccelTuningParamAttrInterface params1 =
       cast<RockAccelTuningParamAttrInterface>(op.getGemm1Params().value());
@@ -177,6 +170,7 @@ commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
   ArrayRef<int64_t> aShape = cast<MemRefType>(a.getType()).getShape();
   ArrayRef<int64_t> bShape = cast<MemRefType>(b.getType()).getShape();
   ArrayRef<int64_t> cShape = cast<MemRefType>(c.getType()).getShape();
+  assert(cShape[1] == bShape[2]);
   GemmSize gemm0Size(/*g=*/aShape[0], /*m=*/bShape[2],
                      /*k=*/aShape[1],
                      /*n=*/aShape[2]);
@@ -200,6 +194,8 @@ commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
   // fusions legit. So the extra pad needs to be swapped and applied.
   out = padMatrix(out, rw, loc, "gemm1N", gemm1ExtraPad.n, "gemm1M",
                   gemm1ExtraPad.m);
+  if (lse)
+    lse = padVector(lse, rw, loc, "gemm1N", gemm1ExtraPad.n);

   if (failed(computeGridSizeAttentionGemmElmtGemm(rw, op, a, b, c))) {
     return op.emitError("failed to compute the grid size of "
@@ -218,7 +214,7 @@ commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
     prePadG0NAttr = rw.getIndexAttr(gemm0Size.n);
   }
   auto newOp = rw.create<GridwiseAttentionAccelOp>(
-      loc, a, b, c, elementwiseInputs, currentSeqLen, out, causal,
+      loc, a, b, c, elementwiseInputs, currentSeqLen, out, lse, causal,
      rw.getStringAttr(op.getArch()),
      rw.getAttr<rock::GemmFeaturesAttr>(op.getGemmFeatures()), blockSizeAttr,
      gridSizeAttr,
@@ -585,34 +581,23 @@ AttentionRewritePattern::matchAndRewrite(AttentionOp op,
                                          ConversionPatternRewriter &rw) const {
   return commonAttentionGemmElmtGemm(
       rw, op, adaptor.getQueries(), adaptor.getKeys(), adaptor.getValues(),
-      adaptor.getOut(), adaptor.getCurrentSeqLen(), adaptor.getCausalAttr(),
-      adaptor.getPreSoftmaxElemWiseInputs(), op.getPreSoftmaxBody(),
+      adaptor.getOut(), adaptor.getLse(), adaptor.getCurrentSeqLen(),
+      adaptor.getCausalAttr(), adaptor.getPreSoftmaxElemWiseInputs(),
+      op.getPreSoftmaxBody(),
       /*enableSoftmax=*/true);
 }

-LogicalResult
-AttentionRewritePattern::computeGridSize(ConversionPatternRewriter &rw,
-                                         AttentionOp op, Value queries,
-                                         Value keys, Value values) const {
-  return computeGridSizeAttentionGemmElmtGemm(rw, op, queries, keys, values);
-}
-
 LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite(
     GemmElementwiseGemmOp op, GemmElementwiseGemmOpAdaptor adaptor,
     ConversionPatternRewriter &rw) const {
   return commonAttentionGemmElmtGemm(
       rw, op, adaptor.getA(), adaptor.getB(), adaptor.getC(), adaptor.getOut(),
+      /*lse=*/nullptr,
       /*currentSeqLen=*/nullptr, /*causal=*/nullptr,
       adaptor.getElemwiseInputs(), op.getPreSecondGemmBody(),
       /*enableSoftmax=*/false);
 }

-LogicalResult GemmElementwiseGemmRewritePattern::computeGridSize(
-    ConversionPatternRewriter &rw, GemmElementwiseGemmOp op, Value a, Value b,
-    Value c) const {
-  return computeGridSizeAttentionGemmElmtGemm(rw, op, a, b, c);
-}
-
 void RockGemmToGridwisePass::runOnOperation() {
   MLIRContext *ctx = &getContext();

mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.cpp

Lines changed: 3 additions & 3 deletions
@@ -70,14 +70,14 @@ GridCoordinates rock::layout::makeGroupedGridLayout(PatternRewriter &b,
   // be slowest changing in the grid.
   int64_t numChiplets = rock::lookupArchInfo(arch).maxNumXCC;
   if (numChiplets > 1) {
-    // It was emphircally found that two chiplets as a group
+    // It was empirically found that two chiplets as a group
     // computing a spatial mxn tile has better locality throughout.
     int64_t numChipletsPerGroup = std::ceil(numChiplets / 2);
     int64_t gridSize = info.gBlocks * info.mBlocks * info.nBlocks;
     bid = rearrangeWorkgroupsForXCC(loc, b, bid, gridSize, numChipletsPerGroup);
   }

-  // Heurisitc to compute groupSize
+  // Heuristic to compute groupSize
   // This also covers the cases where the output width is larger
   // than the input width
   int64_t bitWidthIn = info.inputType.getIntOrFloatBitWidth();
@@ -137,7 +137,7 @@ GridCoordinates rock::layout::makeGxNGridLayout(PatternRewriter &b,
   // be slowest changing in the grid.
   int64_t numChiplets = rock::lookupArchInfo(arch).maxNumXCC;
   if (numChiplets > 1) {
-    // It was emphircally found that two chiplets as a group
+    // It was empirically found that two chiplets as a group
    // computing a spatial mxn tile has better locality throughout.
    int64_t numChipletsPerGroup = std::ceil(numChiplets / 2);
    bid = rearrangeWorkgroupsForXCC(loc, b, bid, gridSize, numChipletsPerGroup);

mlir/lib/Dialect/Rock/Transforms/GridLayoutEmitter.h

Lines changed: 2 additions & 2 deletions
@@ -50,8 +50,8 @@ struct GridLayoutInfo {
   Type outputType;
 };

-/// This function emits the right triplet of <group,block_m,block_n> identifers,
-/// given a flat blockId. This has been adapted from:
+/// This function emits the right triplet of <group,block_m,block_n>
+/// identifiers, given a flat blockId. This has been adapted from:
 /// https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#sphx-glr-getting-started-tutorials-03-matrix-multiplication-py
 ///
 GridCoordinates makeGroupedGridLayout(PatternRewriter &b, Location loc,
