
Commit 3cc6409

add decompose gatherOp transform
1 parent 9978725 commit 3cc6409

5 files changed (+210, -3 lines)

include/gc/Transforms/Passes.td

Lines changed: 11 additions & 1 deletion
@@ -30,7 +30,7 @@ def MergeAlloc : Pass<"gc-merge-alloc", "func::FuncOp"> {
     lifetime of the original memref before merging. This pass schedules the
     offsets to 1) make sure the offsets and address ranges do not overlap if
     two "mergeable" allocations have overlapped lifetime, and 2) reuse the
-    address ranges that are considered "hot" in cache for an later allocation.
+    address ranges that are considered "hot" in cache for a later allocation.
  }];
  let options = [
    Option<"optionAnalysisOnly", "analysis-only", "bool",

@@ -201,6 +201,16 @@ def FoldTensorOperation : Pass<"fold-tensor-operation"> {
  let description = [{
    Remove some useless tensor operations.
  }];
+  let dependentDialects = [
+    "tensor::TensorDialect",
+  ];
+}
+
+def DecomposeTensorOperation : Pass<"decompose-tensor-operation"> {
+  let summary = "Decompose some tensor operations";
+  let description = [{
+    Decompose some tensor operations (concat, gather) into linalg operations.
+  }];
  let dependentDialects = [
    "::mlir::tensor::TensorDialect"
  ];
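
Note: the concat half of the new pass defers to the upstream tensor-dialect patterns (tensor::populateDecomposeTensorConcatPatterns, registered in DecomposeTensorOperation.cpp below). As a rough, hand-written sketch of that rewrite on a static case (value names are illustrative, not output of this commit), a tensor.concat becomes a tensor.empty plus a chain of tensor.insert_slice ops:

    %concat = tensor.concat dim(1) %a, %b
        : (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>

    // becomes, roughly:
    %empty = tensor.empty() : tensor<2x7xf32>
    %fill0 = tensor.insert_slice %a into %empty[0, 0] [2, 3] [1, 1]
        : tensor<2x3xf32> into tensor<2x7xf32>
    %fill1 = tensor.insert_slice %b into %fill0[0, 3] [2, 4] [1, 1]
        : tensor<2x4xf32> into tensor<2x7xf32>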

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ gc_add_mlir_library(GcPasses
   MergeAlloc.cpp
   MergeAllocTickBased.cpp
   FoldTensorOperation.cpp
+  DecomposeTensorOperation.cpp
   LowerToTileVector.cpp

   DEPENDS
lib/gc/Transforms/DecomposeTensorOperation.cpp

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
+//===-- DecomposeTensorOperation.cpp - Decompose tensor ops -----*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "gc/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Casting.h"
+
+namespace mlir {
+namespace gc {
+
+#define GEN_PASS_DEF_DECOMPOSETENSOROPERATION
+#include "gc/Transforms/Passes.h.inc"
+namespace {
+
+/// Decompose `tensor.gather` into `linalg.generic`.
+///
+/// %2 = tensor.gather %0[%1] gather_dims([0]) : (tensor<7x128xf16>,
+///     tensor<1x7x1xindex>) -> tensor<1x7x128xf16>
+///
+/// Becomes
+///
+/// %empty = tensor.empty() : tensor<1x7x128xf16>
+/// %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1,
+///     0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
+///     ["parallel", "parallel", "parallel"]} ins(%expanded : tensor<1x7x1xindex>)
+///     outs(%13 : tensor<1x7x128xf16>) {
+/// ^bb0(%in: index, %out: f16):
+///   %17 = linalg.index 2 : index
+///   %extracted = tensor.extract %0[%in, %17] : tensor<7x128xf16>
+///   linalg.yield %extracted : f16
+/// } -> tensor<1x7x128xf16>
+
+struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
+  using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::GatherOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(gatherOp);
+    Location loc = gatherOp.getLoc();
+    SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
+
+    // Create the destination tensor for the linalg out operand.
+    RankedTensorType dstType = gatherOp.getResultType();
+    Value dstTensor = rewriter.create<tensor::EmptyOp>(
+        loc, tensor::getMixedSizes(rewriter, loc, gatherOp.getResult()),
+        dstType.getElementType());
+
+    // Split the index tensor into one slice per gather dim (linalg inputs).
+    SmallVector<Value> indexTensors;
+    Value originIndexTensor = gatherOp.getIndices();
+    SmallVector<OpFoldResult> indexTensorSize =
+        tensor::getMixedSizes(rewriter, loc, originIndexTensor);
+    SmallVector<OpFoldResult> indexTensorStride(
+        indexTensorSize.size(),
+        getAsIndexOpFoldResult(rewriter.getContext(), 1));
+    SmallVector<OpFoldResult> indexTensorOffset(
+        indexTensorSize.size(),
+        getAsIndexOpFoldResult(rewriter.getContext(), 0));
+    indexTensorSize[indexTensorSize.size() - 1] =
+        getAsIndexOpFoldResult(rewriter.getContext(), 1);
+
+    for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
+      indexTensorOffset[indexTensorSize.size() - 1] =
+          getAsIndexOpFoldResult(rewriter.getContext(), cnt);
+      Value indexTensor = rewriter.create<tensor::ExtractSliceOp>(
+          loc, originIndexTensor, indexTensorOffset, indexTensorSize,
+          indexTensorStride);
+      indexTensors.emplace_back(indexTensor);
+    }
+
+    // Create the affine maps: one per index tensor, plus an identity out map.
+    SmallVector<AffineMap> affineMaps;
+    SmallVector<AffineExpr> dimExprs;
+    size_t dstRank = dstType.getShape().size();
+    for (unsigned i = 0; i < indexTensorSize.size() - 1; ++i)
+      dimExprs.push_back(rewriter.getAffineDimExpr(i));
+    dimExprs.push_back(getAffineConstantExpr(0, rewriter.getContext()));
+
+    for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
+      AffineMap currentMap =
+          AffineMap::get(/*dimCount=*/dstRank, /*symbolCount=*/0, dimExprs,
+                         rewriter.getContext());
+      affineMaps.emplace_back(currentMap);
+    }
+    affineMaps.emplace_back(rewriter.getMultiDimIdentityMap(dstRank));
+
+    // Create the iterator types array.
+    SmallVector<utils::IteratorType> iteratorTypesArray(
+        dstRank, utils::IteratorType::parallel);
+
+    // Check whether the gather op is valid.
+    size_t srcRank = gatherOp.getSourceType().getShape().size();
+    assert(((indexTensorSize.size() - 1) + srcRank == dstRank ||
+            (indexTensorSize.size() - 1) + srcRank ==
+                dstRank + gatherDims.size()) &&
+           "Expected: index_size - 1 + source_size == dst_size or dst_size + "
+           "gather_size.\n");
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        gatherOp, TypeRange(dstType), indexTensors, ValueRange{dstTensor},
+        affineMaps, iteratorTypesArray,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          SmallVector<Value> indexValues(srcRank);
+          bool isShrinkDst = (indexTensorSize.size() - 1) + srcRank ==
+                             dstRank + gatherDims.size();
+          int cnt = 0;
+          // Non-gathered dims use linalg.index; skip gathered dims if shrunk.
+          for (auto i = indexTensorSize.size() - 1; i < dstRank; i++) {
+            while (llvm::find(gatherDims, cnt) != gatherDims.end() &&
+                   isShrinkDst) {
+              cnt++;
+            }
+            indexValues[cnt] = b.create<linalg::IndexOp>(loc, i);
+            cnt++;
+          }
+          // Gathered dims take the incoming index values.
+          for (auto &&[i, dim] : llvm::enumerate(gatherDims)) {
+            indexValues[dim] = args[i];
+          }
+
+          Value extract = b.create<tensor::ExtractOp>(loc, gatherOp.getSource(),
+                                                      indexValues);
+          b.create<linalg::YieldOp>(loc, extract);
+        });
+    return success();
+  }
+};
+
+/// DecomposeTensorOperationPass is a pass that decomposes some tensor
+/// operations like tensor.gather and tensor.concat.
+struct DecomposeTensorOperationPass
+    : public impl::DecomposeTensorOperationBase<DecomposeTensorOperationPass> {
+  void runOnOperation() final {
+    auto *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+
+    patterns.add<DecomposeGatherOp>(patterns.getContext());
+    tensor::populateDecomposeTensorConcatPatterns(patterns);
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(),
+                                            std::move(patterns)))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+} // namespace gc
+} // namespace mlir
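
For reference, here is a hand-written sketch (not compiler output) of what DecomposeGatherOp is expected to produce for the gather_dims([1, 2]) case exercised in the new test below; %source/%indices stand in for the gather's operands, and the SSA and map names are illustrative:

    #map  = affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>
    #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

    %empty = tensor.empty() : tensor<2x3x2x2xf32>
    %idx0 = tensor.extract_slice %indices[0, 0, 0] [2, 3, 1] [1, 1, 1]
        : tensor<2x3x2xindex> to tensor<2x3x1xindex>
    %idx1 = tensor.extract_slice %indices[0, 0, 1] [2, 3, 1] [1, 1, 1]
        : tensor<2x3x2xindex> to tensor<2x3x1xindex>
    %res = linalg.generic
        {indexing_maps = [#map, #map, #map1],
         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%idx0, %idx1 : tensor<2x3x1xindex>, tensor<2x3x1xindex>)
        outs(%empty : tensor<2x3x2x2xf32>) {
    ^bb0(%i0: index, %i1: index, %out: f32):
      // Non-gathered source dims (0 and 3) come from linalg.index; the
      // gathered dims (1 and 2) come from the sliced index tensors.
      %d2 = linalg.index 2 : index
      %d3 = linalg.index 3 : index
      %e = tensor.extract %source[%d2, %i0, %i1, %d3] : tensor<2x2x2x2xf32>
      linalg.yield %e : f32
    } -> tensor<2x3x2x2xf32>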

lib/gc/Transforms/GPU/Pipeline.cpp

Lines changed: 1 addition & 2 deletions
@@ -34,9 +34,8 @@ void populateGPUPipeline(OpPassManager &pm,
     // Add an argument for the GPU context
     pm.addNestedPass<func::FuncOp>(createAddContextArg());
   }
-
+  pm.addPass(createDecomposeTensorOperation());
   pm.addNestedPass<func::FuncOp>(createIterativeTilingAndFusion());
-
   pm.addPass(bufferization::createEmptyTensorEliminationPass());
   pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+// RUN: gc-opt %s -decompose-tensor-operation --split-input-file | FileCheck %s
+
+/// CHECK-LABEL: @gather_single_gather_dim
+func.func @gather_single_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32> {
+  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2x2xf32>
+  /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x2x2xf32>)
+  %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32>
+  return %1 : tensor<2x3x2x2x2xf32>
+}
+
+// -----
+
+/// CHECK-LABEL: @gather_single_gather_dim_no_shrink
+func.func @gather_single_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32> {
+  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x1x2x2xf32>
+  /// CHECK: linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x1x2x2xf32>)
+  %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32>
+  return %1 : tensor<2x3x2x1x2x2xf32>
+}
+
+// -----
+
+/// CHECK-LABEL: @gather_multiple_gather_dim
+func.func @gather_multiple_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32> {
+  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2xf32>
+  /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
+  /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
+  /// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x2xf32>)
+  %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32>
+  return %1 : tensor<2x3x2x2xf32>
+}
+
+// -----
+
+/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink
+func.func @gather_multiple_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32> {
+  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x1x1x2xf32>
+  /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
+  /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
+  /// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1:.*]], %[[EXTRACTSLICE2:.*]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY:.*]] : tensor<2x3x2x1x1x2xf32>)
+  %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32>
+  return %1 : tensor<2x3x2x1x1x2xf32>
+}
