mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (34 additions, 2 deletions)
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
@@ -26,6 +27,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

@@ -235,6 +237,31 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
  return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
}

/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
/// `from`.
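/// For example, dropping dims {2, 3, 4, 5} of a rank-6 tensor produces the
/// reassociation [[0], [1, 2, 3, 4, 5]]: each dropped dimension is grouped
/// with the nearest kept dimension to its left, and leading dropped
/// dimensions are grouped with the first kept dimension.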
tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
                                   const llvm::SmallBitVector &dropDims) {
  auto fromType = cast<ShapedType>(from.getType());
  assert(fromType.getRank() == dropDims.size());
  SmallVector<ReassociationIndices, 2> reassocIdxsVec;
  ReassociationIndices reassocIdxs;

  bool foundKeptDim = false;
  for (int dim = 0; dim < fromType.getRank(); dim++) {
    if (!dropDims.test(dim)) {
      if (foundKeptDim) {
        reassocIdxsVec.push_back(reassocIdxs);
        reassocIdxs.clear();
      }
      foundKeptDim = true;
    }
    reassocIdxs.push_back(dim);
  }
  if (!reassocIdxs.empty())
    reassocIdxsVec.push_back(reassocIdxs);
  return b.create<tensor::CollapseShapeOp>(loc, from, reassocIdxsVec);
}

FailureOr<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
OpOperand &consumerOpOperand) {
@@ -255,6 +282,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
<< "\nNot fusable, not an extract_slice op: " << inputTensor);
return failure();
}
llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();

  // If producer is already in the same block as consumer, we are done.
  if (consumerOpOperand.get().getParentBlock() ==
@@ -272,12 +300,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                consumerOpOperand);

  // Replace use.
  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
  Type consumerType = consumerOpOperand.get().getType();
  // Rank-reduction occurred as part of the extract_slice.
Contributor:
I have just realised ... Why not re-use a rank-reducing tensor.extract_slice instead? It's a bit counter-intuitive that a rank-reducing tensor.extract_slice is replaced with a pair of tensor.extract_slice + tensor.collapse_shape.

Contributor Author:
The extract_slice is done on the producer's arguments, while the collapse_shape is done on its result. I did look into doing the rank reduction on the argument, since the result type is derived from the DPS init, but the pass applies to both named and generic linalg ops, which makes this tricky. In the case of a generic, one of the inputs might use the unit dimension in a non-dim affine expression, which would require updating the affine maps; but since the pass works on the LinalgOp interface there is no way to change them (named ops have an implicit map). Of course one could branch on the op kind, but the code becomes quite complex.

My feeling is that this is work for another pass. With this one, the producing linalg op is moved inside the loop, alongside the consuming generic. Then it's a matter of folding the collapse_shape / extract_slice pair, which sounds like a simple linalg fusion that probably already exists.
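For illustration, here is a sketch (not part of the patch) of the two forms being compared, reusing the shapes from the test below; `%t` and `%i` are placeholder values:

```mlir
// Rank-reducing slice: the unit dims are dropped by the slice itself.
%s = tensor.extract_slice %t[0, %i, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
    : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>

// Form produced by this patch: a full-rank slice, with the fused producer's
// result collapsed back to the consumer's type.
%f = tensor.extract_slice %t[0, %i, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
    : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
%c = tensor.collapse_shape %f [[0], [1, 2, 3, 4, 5]]
    : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
```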

Contributor Author:
Happy for me to close this?

Contributor:
> // Rank-reduction occured as part of the extract_slice.

  • occured -> occurred
  • shouldn't we verify that it's indeed coming from tensor.extract_slice?

Contributor Author:
That function only fuses ExtractSliceOp; see a little above:

```cpp
Value inputTensor = consumerOpOperand.get();

// Must be an extract_slice op to guarantee there are loops we can fuse into.
auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp) {
  LLVM_DEBUG(llvm::dbgs()
             << "\nNot fusable, not an extract_slice op: " << inputTensor);
  return failure();
}
```

  if (cast<ShapedType>(consumerType).getRank() !=
      cast<ShapedType>(def.getType()).getRank())
    def = collapseTo(b, fusedProducer.getLoc(), def, droppedDims);
  // Canonicalizations are not guaranteed to have happened before constructing
  // `fusedProducer`. In the tensor case this can result in temporary type
  // mismatches. Insert a `tensor.cast` op to propagate the transformation
  // invariant that types are compatible.
  if (consumerType != def.getType())
    def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
  consumerOpOperand.set(def);
mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir (63 additions)
@@ -318,3 +318,66 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
  }
  return %for0 : tensor<64x128xf32>
}

// -----

func.func @rank_reduced_extract_slice(%arg0: tensor<6x6x1x1x1x1xf32>, %arg1: tensor<6x6x1x1xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c6 = arith.constant 6 : index
  %cst = arith.constant 0.0 : f32
  %init1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d6, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<6x6x1x1x1x1xf32>, tensor<6x6x1x1xf32>) outs(%fill1 : tensor<6x6x1x1x1x1xf32>) {
  ^bb0(%in: f32, %in_1: f32, %out: f32):
    %10 = arith.mulf %in, %in_1 : f32
    %11 = arith.addf %out, %10 : f32
    linalg.yield %11 : f32
  } -> tensor<6x6x1x1x1x1xf32>
  %init2 = tensor.empty() : tensor<4x6xf32>
  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
    %2 = tensor.extract_slice %0[0, %arg4, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
    %init3 = tensor.empty() : tensor<4x2xf32>
    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
    %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %20 = arith.mulf %in, %in_1 : f32
      %21 = arith.addf %out, %20 : f32
      linalg.yield %21 : f32
    } -> tensor<4x2xf32>
    %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
    scf.yield %4 : tensor<4x6xf32>
  }
  return %1 : tensor<4x6xf32>
}

// CHECK: func @rank_reduced_extract_slice(
// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<6x6x1x1x1x1xf32>
// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<6x6x1x1xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
// CHECK: %[[FILL_PROD:.*]] = linalg.fill ins({{%.*}} : f32)
// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
// CHECK: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
// CHECK: %[[FILL_CONS:.*]] = linalg.fill ins({{%.*}} : f32)
// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<4x2xf32>) -> tensor<4x2xf32>
// CHECK: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
// CHECK-DAG: %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[I]], 0, 0] [6, 2, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x2x1x1xf32>
// CHECK-DAG: %[[FILL_PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>

// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : tensor<6x2x1x1x1x1xf32>, tensor<6x2x1x1xf32>)
// CHECK-SAME: outs(%[[FILL_PROD_SLICE]] : tensor<6x2x1x1x1x1xf32>)
// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
// CHECK: return %[[FOR]] : tensor<4x6xf32>