 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/DialectConversion.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/Errc.h"
 #include "llvm/Support/LogicalResult.h"
 #include <algorithm>
 #include <memory>
@@ -108,6 +110,154 @@ struct AttentionRewritePattern : public OpConversionPattern<AttentionOp> {
                                 ConversionPatternRewriter &rw) const override;
 };
 
+// Move the num_heads dimension to the sequence length dimension. This is
+// useful for the decoding phase: when batch=1, seq_len_q=1 and GQA is used
+// (example: num_heads_q=64, num_heads_kv=8), we can move
+// numRepeats = num_heads_q / num_heads_kv = 8 to the seq_len_q dimension and
+// use the tile size better (otherwise seq_len_q=1 and it will get padded to
+// 32). This reduces the number of workgroups by numRepeats. However, the
+// decoding phase will typically use split_kv anyway to increase the number of
+// workgroups.
+static Value moveNumHeadsToSeqLenQ(OpBuilder builder, Location loc,
+                                   Value inputTensor, int64_t numRepeats) {
+  ArrayRef<int64_t> inpShape =
+      cast<ShapedType>(inputTensor.getType()).getShape();
+
+  assert(inpShape.size() == 3 && "input must be 3D");
+  assert(inpShape[0] % numRepeats == 0 &&
+         "gemmG must be divisible by numRepeats");
+
+  int64_t newGemmG = inpShape[0] / numRepeats;
+  SmallVector<StringRef> startNames = {"gemmG", "headDim", "seqLen"};
+
+  // (gemmG, headDim, seqLen) -> (gemmG / numRepeats, headDim, seqLen,
+  // numRepeats)
+  rock::BottomUpTMBuilder unmerge(builder, startNames, inpShape);
+  unmerge.unmerge({"gemmG", "numRepeats"}, {0, 3}, "gemmG",
+                  {newGemmG, numRepeats});
+  unmerge.passThrough({"seqLen", "headDim"}, {2, 1}, {"seqLen", "headDim"});
+  auto unmergeAttr = unmerge.get();
+  Value matrixUnmerge =
+      builder.create<rock::TransformOp>(loc, inputTensor, unmergeAttr);
+
+  // (gemmG / numRepeats, headDim, seqLen, numRepeats) -> (gemmG / numRepeats,
+  // headDim, seqLen * numRepeats)
+  auto merger = rock::BottomUpTMBuilder::above(unmerge, unmergeAttr);
+  merger.merge("seqLen", 2, {"seqLen", "numRepeats"});
+  merger.passThrough(ArrayRef<uint32_t>{0, 1}, ArrayRef<uint32_t>{0, 1});
+  auto mergerAttr = merger.get();
+  return builder.create<rock::TransformOp>(loc, matrixUnmerge, mergerAttr);
+}
+
+// Same as moveNumHeadsToSeqLenQ() but for the currSeqLen tensor (KV cache).
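+// Shape-wise: (gemmG) -> (gemmG / numRepeats, numRepeats) -> slice numRepeats
+// to 1 -> merge back to (gemmG / numRepeats).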
+static Value moveNumHeadsToSeqLenCurrSeqLen(OpBuilder builder, Location loc,
+                                            Value inputTensor,
+                                            int64_t numRepeats) {
+  ArrayRef<int64_t> inpShape =
+      cast<ShapedType>(inputTensor.getType()).getShape();
+
+  assert(inpShape.size() == 1 && "input must be 1D");
+  assert(inpShape[0] % numRepeats == 0 &&
+         "gemmG must be divisible by numRepeats");
+
+  int64_t newGemmG = inpShape[0] / numRepeats;
+  SmallVector<StringRef> startNames = {"gemmG"};
+
+  // (gemmG) -> (gemmG / numRepeats, numRepeats)
+  rock::BottomUpTMBuilder unmerge(builder, startNames, inpShape);
+  unmerge.unmerge({"gemmG", "numRepeats"}, {0, 1}, "gemmG",
+                  {newGemmG, numRepeats});
+  auto unmergeAttr = unmerge.get();
+  Value matrixUnmerge =
+      builder.create<rock::TransformOp>(loc, inputTensor, unmergeAttr);
+
+  // slice numRepeats to 1
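+  // (keeping a single entry per repeat group assumes the KV-cache length is
+  // the same for every query head that shares a KV head)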
+  auto slicer = rock::BottomUpTMBuilder::above(unmerge, unmergeAttr);
+  slicer.slice({"numRepeats"}, {"numRepeats"}, {0}, {1});
+  slicer.passThrough(ArrayRef<uint32_t>{0}, ArrayRef<uint32_t>{0});
+  auto slicerAttr = slicer.get();
+  Value matrixSliced =
+      builder.create<rock::TransformOp>(loc, matrixUnmerge, slicerAttr);
+
+  // (gemmG / numRepeats, 1) -> (gemmG / numRepeats)
+  auto merger = rock::BottomUpTMBuilder::above(slicer, slicerAttr);
+  merger.merge("seqLen", 0, {"gemmG", "numRepeats"});
+  auto mergerAttr = merger.get();
+  return builder.create<rock::TransformOp>(loc, matrixSliced, mergerAttr);
+}
+
+// Same as moveNumHeadsToSeqLenQ() but for the output tensor.
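+// For example, with numRepeats=8, splitKV=2, seqLen=1 and gemmG=128, the
+// output view (128, 1, headDim) becomes (8, 2, 1, 8, headDim) and is then
+// merged back to (16, 8, headDim).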
+static Value moveNumHeadsToSeqLenOut(OpBuilder builder, Location loc,
+                                     Value inputTensor, int64_t numRepeats,
+                                     int64_t splitKV) {
+  ArrayRef<int64_t> inpShape =
+      cast<ShapedType>(inputTensor.getType()).getShape();
+
+  assert((inpShape.size() == 2 || inpShape.size() == 3) &&
+         "input must be 2D or 3D");
+  assert(inpShape[0] % numRepeats == 0 &&
+         "gemmG must be divisible by numRepeats");
+  assert(inpShape[0] % splitKV == 0 && "gemmG must be divisible by splitKV");
+
+  int64_t newGemmG = inpShape[0] / (numRepeats * splitKV);
+  bool isLSE = inpShape.size() == 2;
+
+  SmallVector<StringRef> startNamesAll = {"gemmG", "seqLen", "headDim"};
+  ArrayRef<StringRef> startNames =
+      ArrayRef<StringRef>(startNamesAll).take_front(inpShape.size());
+
+  // Note that for LSE, there are only two dimensions (gemmG, seqLen)
+  // (gemmG, seqLen, headDim) -> (gemmG / (splitKV*numRepeats), splitKV, seqLen,
+  // numRepeats, headDim)
+  rock::BottomUpTMBuilder unmerge(builder, startNames, inpShape);
+  unmerge.unmerge({"gemmG", "numRepeats", "splitKV"}, {0, 3, 1}, "gemmG",
+                  {newGemmG, numRepeats, splitKV});
+  if (isLSE)
+    unmerge.passThrough({"seqLen"}, {2}, {"seqLen"});
+  else
+    unmerge.passThrough({"seqLen", "headDim"}, {2, 4}, {"seqLen", "headDim"});
+  auto unmergeAttr = unmerge.get();
+  Value matrixUnmerge =
+      builder.create<rock::TransformOp>(loc, inputTensor, unmergeAttr);
+
+  // (gemmG / (splitKV*numRepeats), splitKV, seqLen, numRepeats, headDim) ->
+  // (gemmG / numRepeats, seqLen * numRepeats, headDim)
+  auto merger = rock::BottomUpTMBuilder::above(unmerge, unmergeAttr);
+  merger.merge("seqLen", 1, {"seqLen", "numRepeats"});
+  merger.merge("gemmG", 0, {"gemmG", "splitKV"});
+  if (!isLSE)
+    merger.passThrough({"headDim"}, {2}, {"headDim"});
+  auto mergerAttr = merger.get();
+  return builder.create<rock::TransformOp>(loc, matrixUnmerge, mergerAttr);
+}
+
+// This function implements GQA, moving numRepeats = num_heads_q / num_heads_kv
+// to the seq_len_q dimension. See the moveNumHeadsToSeqLenQ() comment for more
+// details.
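+// When num_heads_q == num_heads_kv, the inputs are returned unchanged and the
+// returned numRepeatsAttr is null.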
+static std::tuple<IntegerAttr, Value, Value, Value, Value, Value, Value>
+processGQA(ConversionPatternRewriter &rw, Location loc, Value queries,
+           Value keys, Value values, Value out, Value lse, Value currentSeqLen,
+           int64_t numHeadsQ, int64_t numHeadsKV, int64_t splitKV) {
+  assert(numHeadsQ % numHeadsKV == 0);
+  IntegerAttr numRepeatsAttr = nullptr;
+
+  if (numHeadsQ != numHeadsKV) {
+    int64_t numRepeats = numHeadsQ / numHeadsKV;
+
+    numRepeatsAttr = rw.getIndexAttr(numRepeats);
+    queries = moveNumHeadsToSeqLenQ(rw, loc, queries, numRepeats);
+    if (currentSeqLen)
+      currentSeqLen =
+          moveNumHeadsToSeqLenCurrSeqLen(rw, loc, currentSeqLen, numRepeats);
+    out = moveNumHeadsToSeqLenOut(rw, loc, out, numRepeats, splitKV);
+    if (lse)
+      lse = moveNumHeadsToSeqLenOut(rw, loc, lse, numRepeats, splitKV);
+  }
+
+  return std::make_tuple(numRepeatsAttr, queries, keys, values, out, lse,
+                         currentSeqLen);
+}
+
 template <typename Op>
 static LogicalResult
 computeGridSizeAttentionGemmElmtGemm(ConversionPatternRewriter &rw, Op op,
@@ -314,6 +464,7 @@ static LogicalResult commonAttentionGemmElmtGemm(
     Value b, Value c, Value out, Value lse, Value currentSeqLen,
     UnitAttr causal, IntegerAttr splitKV, ValueRange elementwiseInputs,
     Region &preSecondOpRegion, bool enableSoftmax, TypeAttr softmaxType,
+    int64_t numHeadsQ, int64_t numHeadsKV,
     std::optional<std::reference_wrapper<const BufferDependencyAnalysis>>
         bufferDeps) {
   Location loc = op->getLoc();
@@ -363,6 +514,15 @@ static LogicalResult commonAttentionGemmElmtGemm(
     std::tie(a, b, c, out) = maybeSplitk.value();
   }
 
+  int64_t splitKVNum = splitKV.getInt();
+
+  // Grouped-Query Attention (GQA)
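+  // (GQA reshaping is only relevant on the attention path; the plain
+  // gemm+elementwise+gemm pattern below passes numHeadsQ == numHeadsKV == 1,
+  // which makes processGQA a no-op.)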
+  IntegerAttr numRepeatsGQA = nullptr;
+  if (enableSoftmax)
+    std::tie(numRepeatsGQA, a, b, c, out, lse, currentSeqLen) =
+        processGQA(rw, op.getLoc(), a, b, c, out, lse, currentSeqLen, numHeadsQ,
+                   numHeadsKV, splitKVNum);
+
   // Note, matrix dimension correctness is handled in the verifier
   ArrayRef<int64_t> aShape = cast<MemRefType>(a.getType()).getShape();
   ArrayRef<int64_t> bShape = cast<MemRefType>(b.getType()).getShape();
@@ -374,7 +534,6 @@ static LogicalResult commonAttentionGemmElmtGemm(
   GemmSize gemm1Size(/*g=*/aShape[0], /*m=*/cShape[2],
                      /*k=*/cShape[1],
                      /*n=*/aShape[2]);
-  int64_t splitKVNum = splitKV.getInt();
   GemmSize gemm0ExtraPad = requiredPadding(params0, gemm0Size, 1, splitKVNum)
                                .value_or(GemmSize{0, 0, 0, 0});
   GemmSize gemm1ExtraPad = requiredPadding(params1, gemm1Size, splitKVNum)
@@ -417,8 +576,9 @@ static LogicalResult commonAttentionGemmElmtGemm(
       loc, a, b, c, elementwiseInputs, currentSeqLen, out, lse, causal, splitKV,
       op.getGemmFeaturesAttr(), op.getStoreMethodAttr(), blockSizeAttr,
       gridSizeAttr,
-      /*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, softmaxType,
-      params0, params1, rw.getDenseI64ArrayAttr(op.getFirstGemmIndices()),
+      /*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr,
+      numRepeatsGQA, softmaxType, params0, params1,
+      rw.getDenseI64ArrayAttr(op.getFirstGemmIndices()),
       rw.getBoolAttr(enableSoftmax));
   bool linalgOpFound = false;
   preSecondOpRegion.walk(
@@ -777,7 +937,8 @@ AttentionRewritePattern::matchAndRewrite(AttentionOp op,
       adaptor.getOut(), adaptor.getLse(), adaptor.getCurrentSeqLen(),
       adaptor.getCausalAttr(), adaptor.getSplitKVAttr(),
       adaptor.getPreSoftmaxElemWiseInputs(), op.getPreSoftmaxBody(),
-      /*enableSoftmax=*/true, op.getSoftmaxTypeAttr(),
+      /*enableSoftmax=*/true, op.getSoftmaxTypeAttr(), adaptor.getNumHeadsQ(),
+      adaptor.getNumHeadsKV(),
       /*bufferDeps=*/std::nullopt);
 }
 
@@ -790,7 +951,8 @@ LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite(
       /*lse=*/nullptr,
       /*currentSeqLen=*/nullptr, /*causal=*/nullptr, splitKV,
       adaptor.getElemwiseInputs(), op.getPreSecondGemmBody(),
-      /*enableSoftmax=*/false, /*softmaxType=*/nullptr, std::cref(bufferDeps));
+      /*enableSoftmax=*/false, /*softmaxType=*/nullptr, /*numHeadsQ=*/1,
+      /*numHeadsKV=*/1, std::cref(bufferDeps));
 }
 
 void RockGemmToGridwisePass::runOnOperation() {