Skip to content

Commit eeee7b5

Browse files
committed
add decompose gatherOp transform
1 parent fdfbd1e commit eeee7b5

File tree

5 files changed

+211
-3
lines changed

5 files changed

+211
-3
lines changed

include/gc/Transforms/Passes.td

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def MergeAlloc : Pass<"gc-merge-alloc", "func::FuncOp"> {
3030
lifetime of the original memref before merging. This pass schedules the
3131
offsets to 1) make sure the offsets and address ranges do not overlap if
3232
two "mergeable" allocations have overlapped lifetime, and 2) reuse the
33-
address ranges that are considered "hot" in cache for an later allocation.
33+
address ranges that are considered "hot" in cache for an later allocation.
3434
}];
3535
let options = [
3636
Option<"optionAnalysisOnly", "analysis-only", "bool",
@@ -201,6 +201,16 @@ def FoldTensorOperation : Pass<"fold-tensor-operation"> {
201201
let description = [{
202202
Remove some useless tensor operations.
203203
}];
204+
let dependentDialects = [
205+
"tensor::TensorDialect",
206+
];
207+
}
208+
209+
def DecomposeTensorOperation : Pass<"decompose-tensor-operation"> {
210+
let summary = "decompose some tensor operation";
211+
let description = [{
212+
Decompose some tensor operations(concat, gather) into linalg operation.
213+
}];
204214
let dependentDialects = [
205215
"::mlir::tensor::TensorDialect"
206216
];

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ gc_add_mlir_library(GcPasses
2525
MergeAlloc.cpp
2626
MergeAllocTickBased.cpp
2727
FoldTensorOperation.cpp
28+
DecomposeTensorOperation.cpp
2829
LowerToTileVector.cpp
2930

3031
DEPENDS
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
//===-- DecomposeTensorOperation.cpp - Decompose tensor ops -----*- C++ -*-===//
3+
//
4+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
#include "gc/Transforms/Passes.h"
10+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
11+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
12+
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15+
#include "llvm/Support/Casting.h"
16+
17+
namespace mlir {
18+
namespace gc {
19+
20+
#define GEN_PASS_DEF_DECOMPOSETENSOROPERATION
21+
#include "gc/Transforms/Passes.h.inc"
22+
namespace {
23+
24+
/// Decompose `tensor.gather` into `linalg.generic`.
25+
///
26+
/// %2 = tensor.gather %0[%1] gather_dims([0]) : (tensor<7x128xf16>,
27+
/// tensor<1x7x1xindex>) -> tensor<1x7x128xf16>
28+
///
29+
/// Becomes
30+
///
31+
/// %empty = tensor.empty() : tensor<1x7x128xf16>
32+
/// %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1,
33+
/// 0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types =
34+
/// ["parallel", "parallel", "parallel"]} ins(%expanded : tensor<1x7x1xindex>)
35+
/// outs(%13 : tensor<1x7x128xf16>) {
36+
/// ^bb0(%in: index, %out: f16):
37+
/// %17 = linalg.index 2 : index
38+
/// %extracted = tensor.extract %0[%in, %17] : tensor<7x128xf16>
39+
/// linalg.yield %extracted : f16
40+
/// } -> tensor<1x7x128xf16>
41+
42+
struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
43+
using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;
44+
45+
LogicalResult matchAndRewrite(tensor::GatherOp gatherOp,
46+
PatternRewriter &rewriter) const override {
47+
OpBuilder::InsertionGuard g(rewriter);
48+
rewriter.setInsertionPoint(gatherOp);
49+
Location loc = gatherOp.getLoc();
50+
SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
51+
52+
// create destination tensor for linalg out
53+
RankedTensorType dstType = gatherOp.getResultType();
54+
Value dstTensor = rewriter.create<tensor::EmptyOp>(
55+
loc, tensor::getMixedSizes(rewriter, loc, gatherOp.getResult()),
56+
dstType.getElementType());
57+
58+
// split index tensor to create the linalg input
59+
SmallVector<Value> indexTensors;
60+
Value originIndexTensor = gatherOp.getIndices();
61+
SmallVector<OpFoldResult> indexTensorSize =
62+
tensor::getMixedSizes(rewriter, loc, originIndexTensor);
63+
SmallVector<OpFoldResult> indexTensorStride(
64+
indexTensorSize.size(),
65+
getAsIndexOpFoldResult(rewriter.getContext(), 1));
66+
SmallVector<OpFoldResult> indexTensorOffset(
67+
indexTensorSize.size(),
68+
getAsIndexOpFoldResult(rewriter.getContext(), 0));
69+
indexTensorSize[indexTensorSize.size() - 1] =
70+
getAsIndexOpFoldResult(rewriter.getContext(), 1);
71+
72+
for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
73+
indexTensorOffset[indexTensorSize.size() - 1] =
74+
getAsIndexOpFoldResult(rewriter.getContext(), cnt);
75+
Value indexTensor = rewriter.create<tensor::ExtractSliceOp>(
76+
loc, originIndexTensor, indexTensorOffset, indexTensorSize,
77+
indexTensorStride);
78+
indexTensors.emplace_back(indexTensor);
79+
}
80+
81+
// create the affine map
82+
SmallVector<AffineMap> affineMaps;
83+
SmallVector<AffineExpr> dimExprs;
84+
size_t dstRank = dstType.getShape().size();
85+
for (unsigned i = 0; i < indexTensorSize.size() - 1; ++i)
86+
dimExprs.push_back(rewriter.getAffineDimExpr(i));
87+
dimExprs.push_back(getAffineConstantExpr(0, rewriter.getContext()));
88+
89+
for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
90+
AffineMap currentMap =
91+
AffineMap::get(/*dimCount=*/dstRank, /*symbolCount=*/0, dimExprs,
92+
rewriter.getContext());
93+
affineMaps.emplace_back(currentMap);
94+
}
95+
affineMaps.emplace_back(rewriter.getMultiDimIdentityMap(dstRank));
96+
97+
// create iterater types array
98+
SmallVector<utils::IteratorType> iteratorTypesArray(
99+
dstRank, utils::IteratorType::parallel);
100+
101+
// check whether the gather op is valid
102+
size_t srcRank = gatherOp.getSourceType().getShape().size();
103+
assert(((indexTensorSize.size() - 1) + srcRank == dstRank ||
104+
(indexTensorSize.size() - 1) + srcRank ==
105+
dstRank + gatherDims.size()) &&
106+
"Expected: index_size - 1 + source_size == dst_size or dst_szie - "
107+
"gather_size. \n");
108+
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
109+
gatherOp, TypeRange(dstType), indexTensors, ValueRange{dstTensor},
110+
affineMaps, iteratorTypesArray,
111+
[&](OpBuilder &b, Location loc, ValueRange args) {
112+
SmallVector<Value> indexValues(srcRank);
113+
bool isShrinkDst = (indexTensorSize.size() - 1) + srcRank ==
114+
dstRank + gatherDims.size();
115+
int cnt = 0;
116+
for (auto i = indexTensorSize.size() - 1; i < dstRank; i++) {
117+
while (llvm::find(gatherDims, cnt) != gatherDims.end() &&
118+
isShrinkDst) {
119+
cnt++;
120+
}
121+
indexValues[cnt] = b.create<linalg::IndexOp>(loc, i);
122+
cnt++;
123+
}
124+
for (auto &&[i, dim] : llvm::enumerate(gatherDims)) {
125+
indexValues[dim] = args[i];
126+
}
127+
128+
Value extract = b.create<tensor::ExtractOp>(loc, gatherOp.getSource(),
129+
indexValues);
130+
b.create<linalg::YieldOp>(loc, extract);
131+
});
132+
return success();
133+
}
134+
};
135+
136+
/// DecomposeTensorOperationPass is a pass that decompose some tensor
137+
/// operations like tensor.gather, tensor.concat.
138+
struct DecomposeTensorOperationPass
139+
: public impl::DecomposeTensorOperationBase<DecomposeTensorOperationPass> {
140+
void runOnOperation() final {
141+
auto *ctx = &getContext();
142+
RewritePatternSet patterns(ctx);
143+
144+
patterns.add<DecomposeGatherOp>(patterns.getContext());
145+
tensor::populateDecomposeTensorConcatPatterns(patterns);
146+
147+
if (failed(applyPatternsAndFoldGreedily(getOperation(),
148+
std::move(patterns)))) {
149+
return signalPassFailure();
150+
}
151+
}
152+
};
153+
} // namespace
154+
} // namespace gc
155+
} // namespace mlir

lib/gc/Transforms/GPU/Pipeline.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@ void populateGPUPipeline(OpPassManager &pm,
3434
// Add an argument for the GPU context
3535
pm.addNestedPass<func::FuncOp>(createAddContextArg());
3636
}
37-
37+
pm.addPass(createDecomposeTensorOperation());
3838
pm.addNestedPass<func::FuncOp>(createIterativeTilingAndFusion());
39-
4039
pm.addPass(bufferization::createEmptyTensorEliminationPass());
4140
pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());
4241

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: gc-opt %s -decompose-tensor-operation --split-input-file | FileCheck %s
2+
3+
// One gathered dimension, rank-reduced result: the gathered source dimension
// does not appear in the result type. FileCheck variables are bound once and
// then referenced, so the checks verify that the generic really consumes the
// index argument and writes into the freshly created empty tensor.
/// CHECK-LABEL: @gather_single_gather_dim
/// CHECK-SAME: %[[ARG0:.*]]: tensor<2x2x2x2xf32>, %[[ARG1:.*]]: tensor<2x3x1xindex>
func.func @gather_single_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32> {
  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2x2xf32>
  /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x2x2xf32>)
  /// CHECK: tensor.extract %[[ARG0]]{{.*}} : tensor<2x2x2x2xf32>
  %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32>
  return %1 : tensor<2x3x2x2x2xf32>
}
10+
11+
// -----
12+
13+
// One gathered dimension, result not rank-reduced: the gathered source
// dimension is kept as a unit dimension in the result type. FileCheck
// variables are bound once and then referenced so the data flow is checked.
/// CHECK-LABEL: @gather_single_gather_dim_no_shrink
/// CHECK-SAME: %[[ARG0:.*]]: tensor<2x2x2x2xf32>, %[[ARG1:.*]]: tensor<2x3x1xindex>
func.func @gather_single_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32> {
  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x1x2x2xf32>
  /// CHECK: linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1]] : tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x1x2x2xf32>)
  /// CHECK: tensor.extract %[[ARG0]]{{.*}} : tensor<2x2x2x2xf32>
  %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32>
  return %1 : tensor<2x3x2x1x2x2xf32>
}
20+
21+
// -----
22+
23+
// Two gathered dimensions, rank-reduced result: the index tensor is split
// into one slice per gathered dimension and both slices feed the generic.
// FileCheck variables are bound once and then referenced so the checks verify
// that the slices come from the index argument and are the generic's inputs.
/// CHECK-LABEL: @gather_multiple_gather_dim
/// CHECK-SAME: %[[ARG0:.*]]: tensor<2x2x2x2xf32>, %[[ARG1:.*]]: tensor<2x3x2xindex>
func.func @gather_multiple_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32> {
  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2xf32>
  /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
  /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
  /// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1]], %[[EXTRACTSLICE2]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x2xf32>)
  /// CHECK: tensor.extract %[[ARG0]]{{.*}} : tensor<2x2x2x2xf32>
  %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32>
  return %1 : tensor<2x3x2x2xf32>
}
32+
33+
// -----
34+
35+
// Two gathered dimensions, result not rank-reduced: both gathered source
// dimensions are kept as unit dimensions in the result type. FileCheck
// variables are bound once and then referenced so the data flow is checked.
/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink
/// CHECK-SAME: %[[ARG0:.*]]: tensor<2x2x2x2xf32>, %[[ARG1:.*]]: tensor<2x3x2xindex>
func.func @gather_multiple_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32> {
  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x1x1x2xf32>
  /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
  /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
  /// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1]], %[[EXTRACTSLICE2]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x1x1x2xf32>)
  /// CHECK: tensor.extract %[[ARG0]]{{.*}} : tensor<2x2x2x2xf32>
  %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32>
  return %1 : tensor<2x3x2x1x1x2xf32>
}

0 commit comments

Comments
 (0)