Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 36 additions & 21 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,19 +811,35 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
}

// Create an expanded transpose op.
static Operation *
createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
SmallVector<ReassociationIndices> reassociation,
Value expandedInput, Value output) {
applyPermutationToVector(reassociation, transposeOp.getPermutation());
// The reassociation map is already permuted; hence we inverse-permute it and
// then flatten it. Then we inverse-permute it again to get the final expanded
// transpose permutation. For example,
//
// permutation = [2, 0, 1]
// reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
//
// inverse permutation = [1, 2, 0]
// applied to reassociation_map and then flattened becomes
// flattened permutation = [2, 3, 4, 5, 0, 1]
// The final permutation is the inverse of the flattened permutation.
//
// Becomes
//
// permutation=[4, 5, 0, 1, 2, 3]

static Operation *createExpandedTransposeOp(PatternRewriter &rewriter,
TransposeOp transposeOp,
Value expandedInput, Value output,
ExpansionInfo &expansionInfo) {
SmallVector<int64_t> newPerm;
for (const auto &reassoc : reassociation) {
for (auto dim : reassoc) {
for (int64_t perm : invertPermutationVector(transposeOp.getPermutation())) {
auto reassoc = expansionInfo.getExpandedDims(perm);
for (int64_t dim : reassoc) {
newPerm.push_back(dim);
}
}
return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
output, newPerm);
output, invertPermutationVector(newPerm));
}

// Create an expanded generic op.
Expand Down Expand Up @@ -857,16 +873,18 @@ static Operation *createExpandedGenericOp(
// Create an expanded fused op that retains the name for certain ops
// such as fill, copy and transpose and produce a generic op for
// rest of linalg ops.
static Operation *createExpandedOp(
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
SmallVector<ReassociationIndices> reassociation) {
static Operation *createExpandedOp(PatternRewriter &rewriter, LinalgOp linalgOp,
TypeRange resultTypes,
ArrayRef<Value> expandedOpOperands,
ArrayRef<Value> outputs,
ArrayRef<AffineMap> expandedOpIndexingMaps,
ExpansionInfo &expansionInfo) {

return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
.Case<TransposeOp>([&](TransposeOp transposeOp) {
return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
expandedOpOperands[0], outputs[0]);
return createExpandedTransposeOp(rewriter, transposeOp,
expandedOpOperands[0], outputs[0],
expansionInfo);
})
.Case<FillOp, CopyOp>([&](Operation *op) {
return clone(rewriter, linalgOp, resultTypes,
Expand Down Expand Up @@ -986,12 +1004,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
}

TypeRange resultTypes = ValueRange(outputs).getTypes();
SmallVector<ReassociationIndices> reassociationBeforeExpansion =
isExpanding ? expandingReshapeOp.getReassociationIndices()
: collapsingReshapeOp.getReassociationIndices();
Operation *fusedOp = createExpandedOp(
rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
Operation *fusedOp =
createExpandedOp(rewriter, linalgOp, resultTypes, expandedOpOperands,
outputs, expandedOpIndexingMaps, expansionInfo);
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
Expand Down
75 changes: 45 additions & 30 deletions mlir/test/Dialect/Linalg/reshape_fusion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,37 @@ func.func @generic_op_reshape_consumer_static(%arg0: tensor<264x4xf32>)
// CHECK-SAME: : tensor<8x33x4xf32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[VAL_0]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2]] output_shape [8, 33, 4] : tensor<264x4xf32> into tensor<8x33x4xf32>
// CHECK: %[[T2:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[T0]], %[[CST]] :
// CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>)
// CHECK: return %[[T2]] : tensor<8x33x4xf32>

// -----

func.func @reshape_as_consumer_transpose
(%a : tensor<4x210x6xf32>)
-> tensor<2x3x4x5x6x7xf32> {
%b = tensor.empty() : tensor<6x4x210xf32>
%c = linalg.transpose
ins(%a : tensor<4x210x6xf32>)
outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
%d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
return %d : tensor<2x3x4x5x6x7xf32>
}
// CHECK: func @reshape_as_consumer_transpose
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x210x6xf32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3], [4, 5]] output_shape [4, 5, 6, 7, 2, 3] : tensor<4x210x6xf32> into tensor<4x5x6x7x2x3xf32>
// CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
// CHECK: %[[T2:.+]] = linalg.transpose ins(%[[T0]] : tensor<4x5x6x7x2x3xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<2x3x4x5x6x7xf32>)
// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
// CHECK: return %[[T2]] : tensor<2x3x4x5x6x7xf32>


// -----

#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
Expand Down Expand Up @@ -884,37 +907,29 @@ func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,

// -----

func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
%arg1 : tensor<?x?xf32>) ->
tensor<?x?xf32>
{
%0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
tensor<?x7x?x8xf32> into tensor<?x?xf32>
%1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
return %1 : tensor<?x?xf32>

func.func @reshape_as_producer_transpose
(%a : tensor<4x5x6x7x2x3xf32>)
-> tensor<6x4x210xf32> {
%b = tensor.empty() : tensor<6x4x210xf32>
%c = tensor.collapse_shape %a [[0], [1, 2, 3], [4, 5]] :
tensor<4x5x6x7x2x3xf32> into tensor<4x210x6xf32>
%d = linalg.transpose
ins(%c : tensor<4x210x6xf32>)
outs(%b : tensor<6x4x210xf32>) permutation = [2, 0, 1]
return %d : tensor<6x4x210xf32>
}

// CHECK: func @linalg_transpose_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
// CHECK: %[[T2:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
// CHECK-SAME: permutation = [2, 3, 0, 1]
// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
// CHECK: return %[[T3]]
// CHECK: func @reshape_as_producer_transpose
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<4x5x6x7x2x3xf32>
// CHECK-DAG: %[[INIT:.+]] = tensor.empty()
// CHECK-DAG: %[[T0:.+]] = tensor.expand_shape %[[INIT]] {{\[\[}}0, 1], [2], [3, 4, 5]] output_shape [2, 3, 4, 5, 6, 7] : tensor<6x4x210xf32> into tensor<2x3x4x5x6x7xf32>
// CHECK: %[[T1:.+]] = linalg.transpose ins(%[[ARG0]] : tensor<4x5x6x7x2x3xf32>)
// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xf32>)
// CHECK-SAME: permutation = [4, 5, 0, 1, 2, 3]
// CHECK: %[[T2:.+]] = tensor.collapse_shape %[[T1]] {{\[\[}}0, 1], [2], [3, 4, 5]] : tensor<2x3x4x5x6x7xf32> into tensor<6x4x210xf32>
// CHECK: return %[[T2]] : tensor<6x4x210xf32>


// -----

Expand Down