Skip to content

Commit c58903c

Browse files
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 7107cdb commit c58903c

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyTilingLevel.cpp

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -132,8 +132,15 @@ static LogicalResult applyTileAndFuseToEachRoot(
132132

133133
if (tilingLevel == IREE::GPU::TilingLevel::PartialReduction) {
134134
tilingOptions.setReductionTilingStrategy(
135-
scf::SCFTilingOptions::ReductionTilingStrategy::
136-
PartialReductionOuterReduction);
135+
ReductionTilingStrategy::PartialReductionOuterReduction);
136+
SmallVector<unsigned> reductionDims;
137+
for (auto [index, iteratorType] :
138+
llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
139+
if (iteratorType == utils::IteratorType::reduction) {
140+
reductionDims.push_back(index);
141+
}
142+
}
143+
tilingOptions.setReductionDims(reductionDims);
137144
}
138145

139146
scf::SCFTileAndFuseOptions tileAndFuseOptions;

compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2501,7 +2501,7 @@ static AffineMap getPartialResultMap(AffineMap map, AttentionOpDetail &opInfo) {
25012501
FailureOr<SmallVector<Value>>
25022502
OnlineAttentionOp::generateInitialTensorForPartialReduction(
25032503
OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
2504-
ArrayRef<int> reductionDim) {
2504+
const llvm::SetVector<unsigned> &reductionDims) {
25052505
FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
25062506
getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
25072507
if (failed(maybeOpInfo)) {
@@ -2555,8 +2555,10 @@ OnlineAttentionOp::generateInitialTensorForPartialReduction(
25552555
}
25562556

25572557
FailureOr<TilingResult> OnlineAttentionOp::tileToPartialReduction(
2558-
OpBuilder &b, Location loc, ValueRange init, ArrayRef<OpFoldResult> offsets,
2559-
ArrayRef<OpFoldResult> sizes, ArrayRef<int> reductionDims) {
2558+
OpBuilder &b, Location loc, ReductionTilingStrategy strategy,
2559+
ValueRange init, ArrayRef<OpFoldResult> offsets,
2560+
ArrayRef<OpFoldResult> sizes,
2561+
const llvm::SetVector<unsigned> &reductionDims) {
25602562
FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
25612563
getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
25622564
if (failed(maybeOpInfo)) {
@@ -2707,10 +2709,9 @@ static Value computeSubAndExp2(OpBuilder &builder, Location loc,
27072709
return genericOp.getResult(0);
27082710
}
27092711

2710-
FailureOr<MergeResult>
2711-
OnlineAttentionOp::mergeReductions(OpBuilder &b, Location loc,
2712-
ValueRange partialReduce,
2713-
ArrayRef<int> reductionDim) {
2712+
FailureOr<MergeResult> OnlineAttentionOp::mergeReductions(
2713+
OpBuilder &b, Location loc, ValueRange partialReduce,
2714+
const llvm::SetVector<unsigned> &reductionDims) {
27142715
FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
27152716
getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());
27162717
if (failed(maybeOpInfo)) {
@@ -2753,8 +2754,10 @@ OnlineAttentionOp::mergeReductions(OpBuilder &b, Location loc,
27532754

27542755
LogicalResult OnlineAttentionOp::getPartialResultTilePosition(
27552756
OpBuilder &b, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2756-
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2757-
SmallVector<OpFoldResult> &resultSizes, ArrayRef<int> reductionDims) {
2757+
ArrayRef<OpFoldResult> sizes,
2758+
const llvm::SetVector<unsigned> &reductionDims,
2759+
SmallVector<OpFoldResult> &resultOffsets,
2760+
SmallVector<OpFoldResult> &resultSizes) {
27582761

27592762
FailureOr<AttentionOpDetail> maybeOpInfo = AttentionOpDetail::get(
27602763
getQueryMap(), getKeyMap(), getValueMap(), getOutputMap());

0 commit comments

Comments
 (0)