@@ -92,19 +92,12 @@ struct GemmElementwiseGemmRewritePattern
   LogicalResult matchAndRewrite(GemmElementwiseGemmOp op,
                                 GemmElementwiseGemmOpAdaptor adaptor,
                                 ConversionPatternRewriter &rw) const override;
-
-  LogicalResult computeGridSize(ConversionPatternRewriter &rw,
-                                GemmElementwiseGemmOp op, Value a, Value b,
-                                Value c) const;
 };
 
 struct AttentionRewritePattern : public OpConversionPattern<AttentionOp> {
   using OpConversionPattern<AttentionOp>::OpConversionPattern;
   LogicalResult matchAndRewrite(AttentionOp op, AttentionOpAdaptor adaptor,
                                 ConversionPatternRewriter &rw) const override;
-
-  LogicalResult computeGridSize(ConversionPatternRewriter &rw, AttentionOp op,
-                                Value queries, Value keys, Value values) const;
 };
 
 template <typename Op>
@@ -139,7 +132,7 @@ computeGridSizeAttentionGemmElmtGemm(ConversionPatternRewriter &rw, Op op,
 static LogicalResult
 commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
                             RockGemmGemmWrapperInterface op, Value a, Value b,
-                            Value c, Value out, Value currentSeqLen,
+                            Value c, Value out, Value lse, Value currentSeqLen,
                             UnitAttr causal, ValueRange elementwiseInputs,
                             Region &preSecondOpRegion, bool enableSoftmax) {
   Location loc = op->getLoc();
@@ -150,17 +143,17 @@ commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
   bool isAccel = rock::isAccel(op.getGemmFeatures());
   if (!isAccel) {
     return op.emitError("Currently, op is only supported on GPUs "
-                        "with matrix accelerator extentions");
+                        "with matrix accelerator extensions");
   }
   if (!op.getGemm0Params().has_value()) {
     return op.emitError("gemm0 params is missing and it should've been "
-                        "assigned by affix-tuing-params");
+                        "assigned by affix-tuning-params");
   }
   RockAccelTuningParamAttrInterface params0 =
       cast<RockAccelTuningParamAttrInterface>(op.getGemm0Params().value());
   if (!op.getGemm1Params().has_value()) {
     return op.emitError("gemm1 params is missing and it should've been "
-                        "assigned by affix-tuing-params");
+                        "assigned by affix-tuning-params");
   }
   RockAccelTuningParamAttrInterface params1 =
       cast<RockAccelTuningParamAttrInterface>(op.getGemm1Params().value());
@@ -177,6 +170,7 @@ commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
   ArrayRef<int64_t> aShape = cast<MemRefType>(a.getType()).getShape();
   ArrayRef<int64_t> bShape = cast<MemRefType>(b.getType()).getShape();
   ArrayRef<int64_t> cShape = cast<MemRefType>(c.getType()).getShape();
+  assert(cShape[1] == bShape[2]);
   GemmSize gemm0Size(/*g=*/aShape[0], /*m=*/bShape[2],
                      /*k=*/aShape[1],
                      /*n=*/aShape[2]);
@@ -200,6 +194,8 @@ commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
   // fusions legit. So the extra pad needs to be swapped and applied.
   out = padMatrix(out, rw, loc, "gemm1N", gemm1ExtraPad.n, "gemm1M",
                   gemm1ExtraPad.m);
+  if (lse)
+    lse = padVector(lse, rw, loc, "gemm1N", gemm1ExtraPad.n);
 
   if (failed(computeGridSizeAttentionGemmElmtGemm(rw, op, a, b, c))) {
     return op.emitError("failed to compute the grid size of "
@@ -218,7 +214,7 @@ commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
     prePadG0NAttr = rw.getIndexAttr(gemm0Size.n);
   }
   auto newOp = rw.create<GridwiseAttentionAccelOp>(
-      loc, a, b, c, elementwiseInputs, currentSeqLen, out, causal,
+      loc, a, b, c, elementwiseInputs, currentSeqLen, out, lse, causal,
       rw.getStringAttr(op.getArch()),
       rw.getAttr<rock::GemmFeaturesAttr>(op.getGemmFeatures()), blockSizeAttr,
       gridSizeAttr,
@@ -585,34 +581,23 @@ AttentionRewritePattern::matchAndRewrite(AttentionOp op,
                                          ConversionPatternRewriter &rw) const {
   return commonAttentionGemmElmtGemm(
       rw, op, adaptor.getQueries(), adaptor.getKeys(), adaptor.getValues(),
-      adaptor.getOut(), adaptor.getCurrentSeqLen(), adaptor.getCausalAttr(),
-      adaptor.getPreSoftmaxElemWiseInputs(), op.getPreSoftmaxBody(),
+      adaptor.getOut(), adaptor.getLse(), adaptor.getCurrentSeqLen(),
+      adaptor.getCausalAttr(), adaptor.getPreSoftmaxElemWiseInputs(),
+      op.getPreSoftmaxBody(),
       /*enableSoftmax=*/true);
 }
 
-LogicalResult
-AttentionRewritePattern::computeGridSize(ConversionPatternRewriter &rw,
-                                         AttentionOp op, Value queries,
-                                         Value keys, Value values) const {
-  return computeGridSizeAttentionGemmElmtGemm(rw, op, queries, keys, values);
-}
-
 LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite(
     GemmElementwiseGemmOp op, GemmElementwiseGemmOpAdaptor adaptor,
     ConversionPatternRewriter &rw) const {
   return commonAttentionGemmElmtGemm(
       rw, op, adaptor.getA(), adaptor.getB(), adaptor.getC(), adaptor.getOut(),
+      /*lse=*/nullptr,
       /*currentSeqLen=*/nullptr, /*causal=*/nullptr,
       adaptor.getElemwiseInputs(), op.getPreSecondGemmBody(),
       /*enableSoftmax=*/false);
 }
 
-LogicalResult GemmElementwiseGemmRewritePattern::computeGridSize(
-    ConversionPatternRewriter &rw, GemmElementwiseGemmOp op, Value a, Value b,
-    Value c) const {
-  return computeGridSizeAttentionGemmElmtGemm(rw, op, a, b, c);
-}
-
 void RockGemmToGridwisePass::runOnOperation() {
   MLIRContext *ctx = &getContext();
   ConversionTarget target(*ctx);
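Note on the new `lse` plumbing above: `out` is padded along both `gemm1M` and `gemm1N`, while `lse` only receives the `gemm1N` pad, which is consistent with `lse` holding one value per output row (presumably the softmax log-sum-exp). Below is a minimal standalone sketch of that padding bookkeeping; the `extraPad` helper and the tile size are illustrative assumptions, not this pass's API (the real pad amounts arrive via `gemm1ExtraPad`, derived from the tuning params).

// Sketch only: shape bookkeeping behind the padVector call, under the
// assumption that padding rounds a dimension up to the next tile multiple.
#include <cassert>
#include <cstdint>

// Hypothetical helper: padding needed to round `dim` up to a multiple of `tile`.
static int64_t extraPad(int64_t dim, int64_t tile) {
  return (tile - dim % tile) % tile;
}

int main() {
  const int64_t gemm1N = 100;   // unpadded output-row count (gemm1N)
  const int64_t nPerBlock = 64; // assumed tile size along gemm1N
  const int64_t pad = extraPad(gemm1N, nPerBlock); // 28
  // `out` needs this pad in its gemm1N dimension (plus a gemm1M pad);
  // `lse` is one value per row, so it needs only the same gemm1N pad.
  assert(gemm1N + pad == 128);
  return 0;
}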