Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion include/gc/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def MergeAlloc : Pass<"gc-merge-alloc", "func::FuncOp"> {
lifetime of the original memref before merging. This pass schedules the
offsets to 1) make sure the offsets and address ranges do not overlap if
two "mergeable" allocations have overlapped lifetime, and 2) reuse the
address ranges that are considered "hot" in cache for an later allocation.
address ranges that are considered "hot" in cache for an later allocation.
}];
let options = [
Option<"optionAnalysisOnly", "analysis-only", "bool",
Expand Down Expand Up @@ -231,6 +231,16 @@ def FoldTensorOperation : Pass<"fold-tensor-operation"> {
let description = [{
Remove some useless tensor operations.
}];
let dependentDialects = [
"tensor::TensorDialect",
];
}

def DecomposeTensorOperation : Pass<"decompose-tensor-operation"> {
let summary = "decompose some tensor operation";
let description = [{
Decompose some tensor operations (concat, gather) into linalg operations.
}];
let dependentDialects = [
"::mlir::tensor::TensorDialect"
];
Expand Down
1 change: 1 addition & 0 deletions lib/gc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ gc_add_mlir_library(GcPasses
MergeAlloc.cpp
MergeAllocTickBased.cpp
FoldTensorOperation.cpp
DecomposeTensorOperation.cpp
LowerToTileVector.cpp

DEPENDS
Expand Down
181 changes: 181 additions & 0 deletions lib/gc/Transforms/DecomposeTensorOperation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
//===-- DecomposeTensorOperation.cpp - DESC ---------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "gc/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Casting.h"

namespace mlir {
namespace gc {

#define GEN_PASS_DEF_DECOMPOSETENSOROPERATION
#include "gc/Transforms/Passes.h.inc"
namespace {

/// Decompose `tensor.gather` into `linalg.generic`.
///
/// The indices tensor is sliced along its last dimension into one
/// `tensor.extract_slice` per entry of `gather_dims`. Every slice becomes an
/// input of the generic op, and the body reads the source element with
/// `tensor.extract`, using the gathered coordinate for the dims listed in
/// `gather_dims` and `linalg.index` for the remaining source dims.
///
///   %2 = tensor.gather %0[%1] gather_dims([0]) : (tensor<7x128xf16>,
///          tensor<1x7x1xindex>) -> tensor<1x7x128xf16>
///
/// Becomes
///
///   %empty = tensor.empty() : tensor<1x7x128xf16>
///   %slice = tensor.extract_slice %1[0, 0, 0] [1, 7, 1] [1, 1, 1]
///          : tensor<1x7x1xindex> to tensor<1x7x1xindex>
///   %3 = linalg.generic {indexing_maps = [
///          affine_map<(d0, d1, d2) -> (d0, d1, 0)>,
///          affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
///          iterator_types = ["parallel", "parallel", "parallel"]}
///          ins(%slice : tensor<1x7x1xindex>)
///          outs(%empty : tensor<1x7x128xf16>) {
///   ^bb0(%in: index, %out: f16):
///     %i = linalg.index 2 : index
///     %extracted = tensor.extract %0[%in, %i] : tensor<7x128xf16>
///     linalg.yield %extracted : f16
///   } -> tensor<1x7x128xf16>
///
/// Both result forms are handled: the one that keeps a unit dim for every
/// gathered source dim, and the rank-reduced one that drops those dims.
struct DecomposeGatherOp : public OpRewritePattern<tensor::GatherOp> {
  using OpRewritePattern<tensor::GatherOp>::OpRewritePattern;

  /// Compute the mixed (static/dynamic) sizes of the destination tensor.
  ///
  /// The leading sizes are the indices sizes without the trailing coordinate
  /// dimension. The trailing sizes follow the source dims: a gathered dim is
  /// either dropped (rank-reduced result) or has size 1, every other dim keeps
  /// the source size.
  ///
  /// The sizes are derived from the indices and the source only. Nothing is
  /// queried from the gather result, so no `tensor.dim` op referring to the
  /// op that is about to be replaced is created.
  SmallVector<OpFoldResult> getDstMixedSizes(PatternRewriter &rewriter,
                                             Location loc,
                                             tensor::GatherOp gatherOp) const {
    SmallVector<OpFoldResult> indexSize =
        tensor::getMixedSizes(rewriter, loc, gatherOp.getIndices());
    SmallVector<OpFoldResult> srcSize =
        tensor::getMixedSizes(rewriter, loc, gatherOp.getSource());
    SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());

    const size_t batchRank = indexSize.size() - 1;
    const size_t dstRank = gatherOp.getResultType().getShape().size();
    const bool isShrinkDst =
        batchRank + srcSize.size() == dstRank + gatherDims.size();

    SmallVector<OpFoldResult> dstSize;
    dstSize.reserve(dstRank);
    dstSize.append(indexSize.begin(), indexSize.begin() + batchRank);

    int64_t srcDim = 0;
    for (size_t i = batchRank; i < dstRank; ++i) {
      // A rank-reduced result has no dim for the gathered source dims.
      while (isShrinkDst && llvm::is_contained(gatherDims, srcDim))
        ++srcDim;
      if (llvm::is_contained(gatherDims, srcDim))
        dstSize.push_back(getAsIndexOpFoldResult(rewriter.getContext(), 1));
      else
        dstSize.push_back(srcSize[srcDim]);
      ++srcDim;
    }
    return dstSize;
  }

  /// Replace `gatherOp` by a `tensor.empty` + `linalg.generic` sequence.
  ///
  /// Returns failure, without touching the IR, when the indices do not have
  /// `index` element type (they are used directly as `tensor.extract`
  /// indices) or when the ranks of indices, source and result do not match
  /// either supported result form.
  LogicalResult matchAndRewrite(tensor::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    Location loc = gatherOp.getLoc();
    MLIRContext *ctx = rewriter.getContext();
    SmallVector<int64_t> gatherDims(gatherOp.getGatherDims());
    RankedTensorType dstType = gatherOp.getResultType();
    Value originIndexTensor = gatherOp.getIndices();
    auto indexType = llvm::cast<RankedTensorType>(originIndexTensor.getType());

    // All legality checks come first: a pattern must not modify the IR
    // before it reports a match failure.
    if (indexType.getShape().empty())
      return rewriter.notifyMatchFailure(
          gatherOp, "expected indices tensor of rank >= 1");
    if (!indexType.getElementType().isIndex())
      return rewriter.notifyMatchFailure(
          gatherOp, "expected indices tensor with `index` element type");

    const size_t batchRank = indexType.getShape().size() - 1;
    const size_t srcRank = gatherOp.getSourceType().getShape().size();
    const size_t dstRank = dstType.getShape().size();
    const bool isShrinkDst =
        batchRank + srcRank == dstRank + gatherDims.size();
    if (!isShrinkDst && batchRank + srcRank != dstRank)
      return rewriter.notifyMatchFailure(
          gatherOp, "expected: index_size - 1 + source_size == dst_size or "
                    "dst_size + gather_size");

    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(gatherOp);

    // Create the destination tensor used as the linalg output.
    Value dstTensor = rewriter.create<tensor::EmptyOp>(
        loc, getDstMixedSizes(rewriter, loc, gatherOp),
        dstType.getElementType());

    // Split the indices tensor along its last dim: one slice per gathered
    // dim, each of them a linalg input.
    SmallVector<OpFoldResult> indexTensorSize =
        tensor::getMixedSizes(rewriter, loc, originIndexTensor);
    SmallVector<OpFoldResult> indexTensorStride(
        indexTensorSize.size(), getAsIndexOpFoldResult(ctx, 1));
    SmallVector<OpFoldResult> indexTensorOffset(
        indexTensorSize.size(), getAsIndexOpFoldResult(ctx, 0));
    indexTensorSize.back() = getAsIndexOpFoldResult(ctx, 1);

    SmallVector<Value> indexTensors;
    indexTensors.reserve(gatherDims.size());
    for (size_t cnt = 0; cnt < gatherDims.size(); cnt++) {
      indexTensorOffset.back() = getAsIndexOpFoldResult(ctx, cnt);
      Value indexTensor = rewriter.create<tensor::ExtractSliceOp>(
          loc, originIndexTensor, indexTensorOffset, indexTensorSize,
          indexTensorStride);
      indexTensors.emplace_back(indexTensor);
    }

    // Every index slice is read with (d0, ..., d{batchRank-1}, 0); the
    // output uses the identity map.
    SmallVector<AffineExpr> dimExprs;
    dimExprs.reserve(batchRank + 1);
    for (unsigned i = 0; i < batchRank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i));
    dimExprs.push_back(getAffineConstantExpr(0, ctx));
    AffineMap indexMap =
        AffineMap::get(/*dimCount=*/dstRank, /*symbolCount=*/0, dimExprs, ctx);
    SmallVector<AffineMap> affineMaps(gatherDims.size(), indexMap);
    affineMaps.emplace_back(rewriter.getMultiDimIdentityMap(dstRank));

    // All loops are parallel.
    SmallVector<utils::IteratorType> iteratorTypesArray(
        dstRank, utils::IteratorType::parallel);

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        gatherOp, TypeRange(dstType), indexTensors, ValueRange{dstTensor},
        affineMaps, iteratorTypesArray,
        [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
          SmallVector<Value> indexValues(srcRank);
          // Gathered source dims take their coordinate from the index
          // slices (the leading block arguments).
          for (auto &&[i, dim] : llvm::enumerate(gatherDims))
            indexValues[dim] = args[i];

          // Remaining source dims follow the trailing iteration dims.
          int64_t srcDim = 0;
          for (size_t i = batchRank; i < dstRank; ++i) {
            while (isShrinkDst && llvm::is_contained(gatherDims, srcDim))
              ++srcDim;
            // In the non-rank-reduced form the unit dim of a gathered
            // source dim needs no `linalg.index`.
            if (!indexValues[srcDim])
              indexValues[srcDim] = b.create<linalg::IndexOp>(nestedLoc, i);
            ++srcDim;
          }

          Value extract = b.create<tensor::ExtractOp>(
              nestedLoc, gatherOp.getSource(), indexValues);
          b.create<linalg::YieldOp>(nestedLoc, extract);
        });
    return success();
  }
};

/// DecomposeTensorOperationPass is a pass that decomposes some tensor
/// operations: `tensor.gather` (through DecomposeGatherOp) and `tensor.concat`
/// (through the patterns added by populateDecomposeTensorConcatPatterns).
struct DecomposeTensorOperationPass
    : public impl::DecomposeTensorOperationBase<DecomposeTensorOperationPass> {
  /// DecomposeGatherOp creates `linalg.generic`, `linalg.index` and
  /// `linalg.yield`, so the linalg dialect has to be loaded before the pass
  /// runs, in addition to the dialects declared for the pass in Passes.td.
  void getDependentDialects(DialectRegistry &registry) const override {
    impl::DecomposeTensorOperationBase<
        DecomposeTensorOperationPass>::getDependentDialects(registry);
    registry.insert<linalg::LinalgDialect>();
  }

  /// Greedily apply the decomposition patterns to the operation the pass runs
  /// on; the pass is marked as failed when the rewrite driver reports failure.
  void runOnOperation() final {
    MLIRContext *ctx = &getContext();
    RewritePatternSet patterns(ctx);

    patterns.add<DecomposeGatherOp>(ctx);
    tensor::populateDecomposeTensorConcatPatterns(patterns);

    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
} // namespace gc
} // namespace mlir
1 change: 1 addition & 0 deletions lib/gc/Transforms/GPU/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ void populateGPUPipeline(OpPassManager &pm,
pm.addNestedPass<func::FuncOp>(createAddContextArg());
}

pm.addPass(createDecomposeTensorOperation());
pm.addNestedPass<func::FuncOp>(createGpuTilingAndFusion());

pm.addPass(bufferization::createEmptyTensorEliminationPass());
Expand Down
72 changes: 72 additions & 0 deletions test/mlir/test/gc/Transforms/DecomposeTensorOperation.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// RUN: gc-opt %s -decompose-tensor-operation --split-input-file | FileCheck %s

/// CHECK-LABEL: @gather_single_gather_dim
func.func @gather_single_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32> {
  // Rank-reduced result: the gathered source dim 1 has no dim in the result.
  // The generic op must write into the tensor.empty captured above, so the
  // outs operand refers to the captured name instead of re-capturing it.
  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2x2xf32>
  /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x2x2xf32>)
  %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x2x2xf32>
  return %1 : tensor<2x3x2x2x2xf32>
}

// -----

/// CHECK-LABEL: @gather_single_gather_dim_no_shrink
func.func @gather_single_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32> {
  // Non-rank-reduced result: the gathered source dim 1 stays as a unit dim.
  // The outs operand refers to the captured tensor.empty result.
  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x1x2x2xf32>
  /// CHECK: linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x1x2x2xf32>)
  %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<2x2x2x2xf32>, tensor<2x3x1xindex>) -> tensor<2x3x2x1x2x2xf32>
  return %1 : tensor<2x3x2x1x2x2xf32>
}

// -----

/// CHECK-LABEL: @gather_multiple_gather_dim
func.func @gather_multiple_gather_dim(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32> {
  // Two gathered dims, rank-reduced result: one index slice per gathered dim.
  // Both slices must come from the same indices tensor and must be the inputs
  // of the generic op, which writes into the captured tensor.empty result.
  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x2xf32>
  /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
  /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
  /// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1]], %[[EXTRACTSLICE2]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x2xf32>)
  %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x2xf32>
  return %1 : tensor<2x3x2x2xf32>
}

// -----

/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink
func.func @gather_multiple_gather_dim_no_shrink(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32> {
  // Two gathered dims kept as unit dims in the result. The slices, the
  // indices tensor and the tensor.empty result are matched by their captured
  // names at every use.
  /// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x3x2x1x1x2xf32>
  /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
  /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 1] [2, 3, 1] [1, 1, 1] : tensor<2x3x2xindex> to tensor<2x3x1xindex>
  /// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1]], %[[EXTRACTSLICE2]] : tensor<2x3x1xindex>, tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x2x1x1x2xf32>)
  %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<2x3x2xindex>) -> tensor<2x3x2x1x1x2xf32>
  return %1 : tensor<2x3x2x1x1x2xf32>
}

// -----

/// CHECK-LABEL: @gather_single_gather_dim_dynamic
func.func @gather_single_gather_dim_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32> {
  // Dynamic source, rank-reduced result: the three dynamic result sizes must
  // be the three tensor.dim values captured here, in this order, and the
  // generic op must write into the tensor.empty built from them.
  /// CHECK: %[[DIM1:.*]] = tensor.dim
  /// CHECK: %[[DIM2:.*]] = tensor.dim
  /// CHECK: %[[DIM3:.*]] = tensor.dim
  /// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1]], %[[DIM2]], %[[DIM3]]) : tensor<2x3x?x?x?xf32>
  /// CHECK: linalg.generic {{.*}} iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG1:.*]] : tensor<2x3x1xindex>) outs(%[[EMPTY]] : tensor<2x3x?x?x?xf32>)
  %1 = tensor.gather %arg0[%arg1] gather_dims([1]) : (tensor<?x?x?x?xf32>, tensor<2x3x1xindex>) -> tensor<2x3x?x?x?xf32>
  return %1 : tensor<2x3x?x?x?xf32>
}

// -----

/// CHECK-LABEL: @gather_multiple_gather_dim_no_shrink_dynamic
func.func @gather_multiple_gather_dim_no_shrink_dynamic(%arg0: tensor<2x2x2x2xf32>, %arg1: tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32> {
  // Dynamic indices, non-rank-reduced result. The first pair of tensor.dim
  // values sizes the tensor.empty, the second pair sizes both index slices.
  // Every use refers to the captured name instead of re-capturing it.
  /// CHECK: %[[DIM1:.*]] = tensor.dim
  /// CHECK: %[[DIM2:.*]] = tensor.dim
  /// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM1]], %[[DIM2]]) : tensor<?x?x2x1x1x2xf32>
  /// CHECK: %[[DIM3:.*]] = tensor.dim
  /// CHECK: %[[DIM4:.*]] = tensor.dim
  /// CHECK: %[[EXTRACTSLICE1:.*]] = tensor.extract_slice %[[ARG1:.*]][0, 0, 0] [%[[DIM3]], %[[DIM4]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
  /// CHECK: %[[EXTRACTSLICE2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, 1] [%[[DIM3]], %[[DIM4]], 1] [1, 1, 1] : tensor<?x?x2xindex> to tensor<?x?x1xindex>
  /// CHECK: linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%[[EXTRACTSLICE1]], %[[EXTRACTSLICE2]] : tensor<?x?x1xindex>, tensor<?x?x1xindex>) outs(%[[EMPTY]] : tensor<?x?x2x1x1x2xf32>)
  %1 = tensor.gather %arg0[%arg1] gather_dims([1, 2]) : (tensor<2x2x2x2xf32>, tensor<?x?x2xindex>) -> tensor<?x?x2x1x1x2xf32>
  return %1 : tensor<?x?x2x1x1x2xf32>
}