Skip to content

Commit 74eb560

Browse files
authored
Merge branch 'develop' into upstream_merge_55
2 parents 966a9b3 + 9316d7b commit 74eb560

22 files changed

+359
-146
lines changed

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ def Rock_AttentionOp
220220
Optional<TensorOrMemRefOf<[F32, F16, BF16]>>:$lse,
221221
UnitAttr:$qTransposed, UnitAttr:$kTransposed, UnitAttr:$vTransposed,
222222
UnitAttr:$oTransposed, UnitAttr:$causal, StrAttr:$arch,
223-
Rock_GemmFeaturesAttr:$features, OptionalAttr<I32Attr>:$numCU,
223+
Rock_GemmFeaturesAttr:$features, OptionalAttr<TypeAttr>:$softmaxType,
224+
OptionalAttr<I32Attr>:$numCU,
224225
OptionalAttr<RockTuningParamAttrInterface>:$params0,
225226
OptionalAttr<RockTuningParamAttrInterface>:$params1,
226227
I32Attr:$firstGemmIdx)>,
@@ -568,6 +569,7 @@ def Rock_GridwiseAttentionAccelOp
568569
I32Attr:$gridSize, UnitAttr:$disableQBypassLDS,
569570
OptionalAttr<IndexAttr>:$prePadG0M,
570571
OptionalAttr<IndexAttr>:$prePadG0N,
572+
OptionalAttr<TypeAttr>:$softmaxType,
571573
RockAccelTuningParamAttrInterface:$params0,
572574
RockAccelTuningParamAttrInterface:$params1, I32Attr:$firstGemmIdx,
573575
DefaultValuedOptionalAttr<BoolAttr, "true">:$enableSoftmax)> {

mlir/include/mlir/Dialect/Rock/utility/builderUtils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,16 @@ Value createZeroConstantOp(OpBuilder &b, Location loc, Type type);
3232
Value createTypeConversionOp(OpBuilder &b, Location loc, Value source,
3333
Type destType);
3434

35+
// Utility function to perform cast
36+
// and copy to another memref using a Linalg Generic.
3537
void createTypeConversionLaGeneric(PatternRewriter &rewriter, Location loc,
3638
Value src, Value dst);
39+
40+
// Utility function to perform cast
41+
// and copy to another memref using a vector store. This flattens the vectors.
42+
void createTypeConversionFlatAndStore(PatternRewriter &rewriter, Location loc,
43+
Value src, Value dst);
44+
3745
/// Utility function to collapse an multi-dimensional memref to 1D.
3846
Value createCollapseShapeOp(OpBuilder &b, Location loc, Value source);
3947

mlir/include/mlir/Dialect/Rock/utility/loweringUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ AffineMap getIdxReversalMap(OpBuilder &b);
196196
// helper to create ReassociationIndices for flattening
197197
ReassociationIndices getReassociationForFlattening(ShapedType srcTp);
198198

199+
// helper to obtain a flattened memref
200+
Value getFlattenedMemref(OpBuilder &b, Value nonFlatMemRef);
201+
199202
/// Construct a `memref.view` operation that interprets the buffer `buffer`,
200203
/// whose elements are bytes, as a buffer of `type`.
201204
TypedValue<MemRefType> viewBufferAs(OpBuilder &b, Value buffer, Type type);

mlir/lib/Conversion/TosaToRock/TosaToRock.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1843,7 +1843,8 @@ struct AttentionRewritePattern : public OpRewritePattern<tosa::MatMulOp> {
18431843
/*kTransposed=*/nullptr,
18441844
/*vTransposed=*/nullptr,
18451845
/*oTransposed=*/nullptr, causalAttr, arch,
1846-
rewriter.getAttr<rock::GemmFeaturesAttr>(features), numCUAttr,
1846+
rewriter.getAttr<rock::GemmFeaturesAttr>(features),
1847+
/*softmaxType=*/nullptr, numCUAttr,
18471848
/*params0=*/nullptr, /*params1=*/nullptr,
18481849
/*firstGemmIdx=*/rewriter.getI32IntegerAttr(0));
18491850

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,6 +1876,10 @@ LogicalResult GridwiseAttentionAccelOp::verify() {
18761876
return emitError("LSE only works for attention.");
18771877
}
18781878

1879+
if (!getEnableSoftmax() && getSoftmaxType()) {
1880+
return emitError("Setting softmax type only works for attention.");
1881+
}
1882+
18791883
int64_t linalgOpCount = 0;
18801884
getPreSoftmaxBody().walk([&](linalg::GenericOp genOp) { linalgOpCount++; });
18811885
if (linalgOpCount > 1) {

mlir/lib/Dialect/Rock/Transforms/GemmToGridwise.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,11 @@ computeGridSizeAttentionGemmElmtGemm(ConversionPatternRewriter &rw, Op op,
129129
return success();
130130
}
131131

132-
static LogicalResult
133-
commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
134-
RockGemmGemmWrapperInterface op, Value a, Value b,
135-
Value c, Value out, Value lse, Value currentSeqLen,
136-
UnitAttr causal, ValueRange elementwiseInputs,
137-
Region &preSecondOpRegion, bool enableSoftmax) {
132+
static LogicalResult commonAttentionGemmElmtGemm(
133+
ConversionPatternRewriter &rw, RockGemmGemmWrapperInterface op, Value a,
134+
Value b, Value c, Value out, Value lse, Value currentSeqLen,
135+
UnitAttr causal, ValueRange elementwiseInputs, Region &preSecondOpRegion,
136+
bool enableSoftmax, TypeAttr softmaxType) {
138137
Location loc = op->getLoc();
139138

140139
if (!isa<MemRefType>(op.getAType()))
@@ -218,8 +217,8 @@ commonAttentionGemmElmtGemm(ConversionPatternRewriter &rw,
218217
rw.getStringAttr(op.getArch()),
219218
rw.getAttr<rock::GemmFeaturesAttr>(op.getGemmFeatures()), blockSizeAttr,
220219
gridSizeAttr,
221-
/*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, params0,
222-
params1, rw.getI32IntegerAttr(op.getFirstGemmIndex()),
220+
/*disableQBypassLDS=*/nullptr, prePadG0MAttr, prePadG0NAttr, softmaxType,
221+
params0, params1, rw.getI32IntegerAttr(op.getFirstGemmIndex()),
223222
rw.getBoolAttr(enableSoftmax));
224223
bool linalgOpFound = false;
225224
preSecondOpRegion.walk(
@@ -584,7 +583,7 @@ AttentionRewritePattern::matchAndRewrite(AttentionOp op,
584583
adaptor.getOut(), adaptor.getLse(), adaptor.getCurrentSeqLen(),
585584
adaptor.getCausalAttr(), adaptor.getPreSoftmaxElemWiseInputs(),
586585
op.getPreSoftmaxBody(),
587-
/*enableSoftmax=*/true);
586+
/*enableSoftmax=*/true, op.getSoftmaxTypeAttr());
588587
}
589588

590589
LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite(
@@ -595,7 +594,7 @@ LogicalResult GemmElementwiseGemmRewritePattern::matchAndRewrite(
595594
/*lse=*/nullptr,
596595
/*currentSeqLen=*/nullptr, /*causal=*/nullptr,
597596
adaptor.getElemwiseInputs(), op.getPreSecondGemmBody(),
598-
/*enableSoftmax=*/false);
597+
/*enableSoftmax=*/false, /*softmaxType=*/nullptr);
599598
}
600599

601600
void RockGemmToGridwisePass::runOnOperation() {

0 commit comments

Comments
 (0)