5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -43,6 +43,11 @@ FailureOr<RankedTensorType>
computeTransposedType(RankedTensorType rankedTensorType,
ArrayRef<int64_t> transposeVector);

/// Create a tensor.collapse_shape op that drops the unit dimensions listed in
/// `dropDims` from tensor `from`.
CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
const llvm::SmallBitVector &dropDims);

/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
/// source tensor or inserts the source tensor into a destination tensor with
/// the same shape.
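For reference, the helper declared above groups each dropped unit dimension with an adjacent kept dimension and emits a single tensor.collapse_shape. A minimal sketch of the IR it produces for the case exercised by the new test below (dropDims = {0} on a tensor<1x6x2xf32>); the function name is illustrative only:

func.func @drop_leading_unit_dim(%t: tensor<1x6x2xf32>) -> tensor<6x2xf32> {
  // Dimension 0 (unit) is grouped with dimension 1; dimension 2 stays alone.
  %collapsed = tensor.collapse_shape %t [[0, 1], [2]]
      : tensor<1x6x2xf32> into tensor<6x2xf32>
  return %collapsed : tensor<6x2xf32>
}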
12 changes: 10 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
@@ -26,6 +27,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

@@ -254,6 +256,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
<< "\nNot fusable, not an extract_slice op: " << inputTensor);
return failure();
}
llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();

// If producer is already in the same block as consumer, we are done.
if (consumerOpOperand.get().getParentBlock() ==
@@ -271,12 +274,17 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
consumerOpOperand);

// Replace use.
Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
Type consumerType = consumerOpOperand.get().getType();
// If rank-reduction occurred as part of the extract_slice, drop the unit
// dimensions from the fused producer result to match the consumer's rank.
if (cast<ShapedType>(consumerType).getRank() !=
cast<ShapedType>(def.getType()).getRank())
def =
tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
// Canonicalizations are not guaranteed to have happened before constructing
// `fusedProducer`. In the tensor case this can result in temporary type
// mismatches. Insert a `tensor.cast` op to propagate the transformation
// invariant that types are compatible.
Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
Type consumerType = consumerOpOperand.get().getType();
if (consumerType != def.getType())
def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
consumerOpOperand.set(def);
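The tensor.cast fallback above covers mismatches that are not due to rank reduction. A minimal sketch with hypothetical shapes: if the fused producer result still carries dynamic sizes while the consumer operand type is static, the inserted cast keeps the IR type-consistent until canonicalization tightens the types.

func.func @cast_to_consumer_type(%def: tensor<?x?xf32>) -> tensor<4x6xf32> {
  // Refine the dynamic producer type to the static consumer operand type.
  %cast = tensor.cast %def : tensor<?x?xf32> to tensor<4x6xf32>
  return %cast : tensor<4x6xf32>
}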
33 changes: 33 additions & 0 deletions mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -94,6 +94,39 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
return transposedTensorType;
}

/// Create a tensor.collapse_shape op that drops the unit dimensions listed in
/// `dropDims` from tensor `from`.
CollapseShapeOp
mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
const llvm::SmallBitVector &dropDims) {
auto fromType = cast<ShapedType>(from.getType());
int64_t rank = fromType.getRank();
assert(rank == static_cast<int64_t>(dropDims.size()) &&
"dropDims size does not match rank of the `from` tensor");
assert(llvm::all_of(
dropDims.set_bits(),
[&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
"Dropping a non-unit dimension");
// Reassociation maps for the corresponding tensor.collapse_shape.
SmallVector<ReassociationIndices, 2> reassocMaps;
// First dimension that has not yet been assigned to a reassociation group.
int64_t nextDimToGroup = 0;
llvm::SmallBitVector keptDims(dropDims);
keptDims.flip();
int64_t lastSetBit = keptDims.find_last();
for (int64_t setBit : keptDims.set_bits()) {
// Group consecutive dropped dimensions with the next non-dropped dimension.
// If this is the last kept dimension, also group all subsequent dropped
// dimensions, if any.
int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
nextDimToGroup = setBit + 1;
}
return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
}

bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
llvm::SmallBitVector droppedDims = op.getDroppedDims();
int64_t srcDim = 0;
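To illustrate the grouping loop in dropGivenUnitDims with a hypothetical case: for dropDims = {0, 2, 4} on a tensor<1x6x1x2x1xf32>, the kept dimensions are 1 and 3; dropped dimension 0 is grouped with 1, and the trailing dropped dimension 4 is folded into the last group together with 2 and 3, so the computed reassociation is [[0, 1], [2, 3, 4]]:

func.func @drop_interior_and_trailing_unit_dims(%t: tensor<1x6x1x2x1xf32>) -> tensor<6x2xf32> {
  // Groups: {0, 1} -> 6 and {2, 3, 4} -> 2.
  %collapsed = tensor.collapse_shape %t [[0, 1], [2, 3, 4]]
      : tensor<1x6x1x2x1xf32> into tensor<6x2xf32>
  return %collapsed : tensor<6x2xf32>
}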
73 changes: 73 additions & 0 deletions mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -318,3 +318,76 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
}
return %for0 : tensor<64x128xf32>
}

// -----

#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @rank_reduced_extract_slice(
%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>,
%arg3: tensor<1x6x6xf32>, %arg4: tensor<4x6xf32>, %arg5: tensor<4x2xf32>
) -> tensor<4x6xf32> {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c6 = arith.constant 6 : index
%0 = linalg.generic
{indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%arg3 : tensor<1x6x6xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%10 = arith.mulf %in, %in_1 : f32
%11 = arith.addf %out, %10 : f32
linalg.yield %11 : f32
} -> tensor<1x6x6xf32>
%1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (tensor<4x6xf32>) {
%2 = tensor.extract_slice %0[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
%3 = linalg.generic
{indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%arg5 : tensor<4x2xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%20 = arith.mulf %in, %in_1 : f32
%21 = arith.addf %out, %20 : f32
linalg.yield %21 : f32
} -> tensor<4x2xf32>
%4 = tensor.insert_slice %3 into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
scf.yield %4 : tensor<4x6xf32>
}
return %1 : tensor<4x6xf32>
}

// CHECK: func @rank_reduced_extract_slice(
// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
// CHECK-SAME: %[[ARG3:[0-9a-z]*]]: tensor<1x6x6xf32>
// CHECK-SAME: %[[ARG4:[0-9a-z]*]]: tensor<4x6xf32>
// CHECK-SAME: %[[ARG5:[0-9a-z]*]]: tensor<4x2xf32>

// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index

// The scf.for loop comes right after the constants; no linalg.generic is
// emitted before it.
// CHECK-NOT: linalg.generic
// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[ARG4]])

// Producer linalg.generic now inside the loop, with tiled args sliced before
// it.
// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[ARG3]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
//
// The consumer uses a rank-reduced version of the producer result, so a
// collapse_shape is generated.
// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
// CHECK-SAME: outs(%[[ARG5]] : tensor<4x2xf32>)
// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
// CHECK: return %[[FOR]] : tensor<4x6xf32>