@@ -129,12 +129,11 @@ computeGridSizeAttentionGemmElmtGemm(ConversionPatternRewriter &rw, Op op,
129129 return success ();
130130}
131131
132- static LogicalResult
133- commonAttentionGemmElmtGemm (ConversionPatternRewriter &rw,
134- RockGemmGemmWrapperInterface op, Value a, Value b,
135- Value c, Value out, Value lse, Value currentSeqLen,
136- UnitAttr causal, ValueRange elementwiseInputs,
137- Region &preSecondOpRegion, bool enableSoftmax) {
132+ static LogicalResult commonAttentionGemmElmtGemm (
133+ ConversionPatternRewriter &rw, RockGemmGemmWrapperInterface op, Value a,
134+ Value b, Value c, Value out, Value lse, Value currentSeqLen,
135+ UnitAttr causal, ValueRange elementwiseInputs, Region &preSecondOpRegion,
136+ bool enableSoftmax, TypeAttr softmaxType) {
138137 Location loc = op->getLoc ();
139138
140139 if (!isa<MemRefType>(op.getAType ()))
@@ -218,8 +217,8 @@ commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
218217 rw.getStringAttr (op.getArch ()),
219218 rw.getAttr <rock::GemmFeaturesAttr>(op.getGemmFeatures ()), blockSizeAttr,
220219 gridSizeAttr,
221- /* disableQBypassLDS=*/ nullptr , prePadG0MAttr, prePadG0NAttr, params0 ,
222- params1, rw.getI32IntegerAttr (op.getFirstGemmIndex ()),
220+ /* disableQBypassLDS=*/ nullptr , prePadG0MAttr, prePadG0NAttr, softmaxType ,
221+ params0, params1, rw.getI32IntegerAttr (op.getFirstGemmIndex ()),
223222 rw.getBoolAttr (enableSoftmax));
224223 bool linalgOpFound = false ;
225224 preSecondOpRegion.walk (
@@ -584,7 +583,7 @@ AttentionRewritePattern::matchAndRewrite(AttentionOp op,
584583 adaptor.getOut (), adaptor.getLse (), adaptor.getCurrentSeqLen (),
585584 adaptor.getCausalAttr (), adaptor.getPreSoftmaxElemWiseInputs (),
586585 op.getPreSoftmaxBody (),
587- /* enableSoftmax=*/ true );
586+ /* enableSoftmax=*/ true , op. getSoftmaxTypeAttr () );
588587}
589588
590589LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite (
@@ -595,7 +594,7 @@ LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite(
595594 /* lse=*/ nullptr ,
596595 /* currentSeqLen=*/ nullptr , /* causal=*/ nullptr ,
597596 adaptor.getElemwiseInputs (), op.getPreSecondGemmBody (),
598- /* enableSoftmax=*/ false );
597+ /* enableSoftmax=*/ false , /* softmaxType= */ nullptr );
599598}
600599
601600void RockGemmToGridwisePass::runOnOperation () {
0 commit comments