mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (37 additions, 11 deletions)
@@ -927,17 +927,43 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
       iteratorTypes[j] = type;
 
   TypeRange resultTypes = ValueRange(outputs).getTypes();
-  auto fusedOp =
-      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
-                                 /*inputs=*/expandedOpOperands, outputs,
-                                 expandedOpIndexingMaps, iteratorTypes);
-  Region &fusedRegion = fusedOp->getRegion(0);
-  Region &originalRegion = linalgOp->getRegion(0);
-  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());
-
-  // Update the index accesses after the expansion.
-  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);
-
+  Operation *fusedOp;
+
+  TypeSwitch<Operation *>(linalgOp.getOperation())
+      .Case<GenericOp>([&](GenericOp op) {
+        fusedOp = rewriter.create<GenericOp>(
+            linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
+            expandedOpIndexingMaps, iteratorTypes);
+        Region &fusedRegion = fusedOp->getRegion(0);
+        Region &originalRegion = linalgOp->getRegion(0);
+        rewriter.cloneRegionBefore(originalRegion, fusedRegion,
+                                   fusedRegion.begin());
+
+        // Update the index accesses after the expansion.
+        updateExpandedGenericOpRegion(rewriter, loc, fusedRegion,
+                                      expansionInfo);
+      })
+      .Case<TransposeOp>([&](TransposeOp op) {
+        SmallVector<ReassociationIndices> reassociation =
+            isExpanding ? expandingReshapeOp.getReassociationIndices()
+                        : collapsingReshapeOp.getReassociationIndices();
+        applyPermutationToVector(reassociation, op.getPermutation());
+        SmallVector<int64_t> newPerm;
+        for (auto reassoc : reassociation) {
+          for (auto dim : reassoc) {
+            newPerm.push_back(dim);
+          }
+        }
+        fusedOp = rewriter.create<TransposeOp>(
+            linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
+      })
+      // All other expandable linalg ops that are not generic or transpose can
+      // be cloned with the expanded input and output operands.
+      .Default([&](Operation *op) {
+        fusedOp = clone(
+            rewriter, linalgOp, resultTypes,
+            llvm::to_vector(llvm::concat<Value>(expandedOpOperands, outputs)));
+      });
   // Reshape the result values to their original shape if this is a collapsing
   // reshape folded into its consumer.
   SmallVector<Value> resultVals;
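The interesting piece in the change above is the `TransposeOp` case: the reshape's reassociation groups are permuted by the transpose's original permutation and then flattened into a permutation over the expanded rank. A minimal standalone sketch of that computation follows (plain C++ with no MLIR dependencies; `ReassociationIndices` is re-declared here as a plain vector alias for illustration, and it assumes MLIR's `applyPermutationToVector` convention `result[i] = input[perm[i]]`):

```cpp
// Sketch: derive the expanded-rank permutation for a fused linalg.transpose.
#include <cstdint>
#include <iostream>
#include <vector>

using ReassociationIndices = std::vector<int64_t>; // hypothetical stand-in

int main() {
  // tensor<?x7x?x8xf32> collapsed via [[0, 1], [2, 3]] into tensor<?x?xf32>,
  // then transposed with permutation = [1, 0] (as in the new test below).
  std::vector<ReassociationIndices> reassociation = {{0, 1}, {2, 3}};
  std::vector<int64_t> perm = {1, 0};

  // Assumed applyPermutationToVector semantics:
  // permuted[i] = reassociation[perm[i]].
  std::vector<ReassociationIndices> permuted;
  permuted.reserve(perm.size());
  for (int64_t p : perm)
    permuted.push_back(reassociation[p]);

  // Flatten the permuted groups into the permutation over expanded dims.
  std::vector<int64_t> newPerm;
  for (const ReassociationIndices &group : permuted)
    for (int64_t dim : group)
      newPerm.push_back(dim);

  for (int64_t d : newPerm)
    std::cout << d << ' '; // prints: 2 3 0 1
  std::cout << '\n';
}
```

With the `[[0, 1], [2, 3]]` reassociation and permutation `[1, 0]`, this yields `[2, 3, 0, 1]`, which is exactly the `permutation` the new `@linalg_transpose_reshape_producer_fusion` test checks for.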
mlir/test/Dialect/Linalg/reshape_fusion.mlir (38 additions, 11 deletions)
@@ -753,7 +753,6 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
   return %1 : tensor<?x?x4x5xf32>
 }
 
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK: func @linalg_add_reshape_consumer_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -774,18 +773,13 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 // CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG2]], %[[C1]] : tensor<?x?xf32>
 // CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
 // CHECK: %[[T3:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], %[[VAL_2]], 4, 5] : tensor<?x?xf32> into tensor<?x?x4x5xf32>
-// CHECK: %[[T4:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
-// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
+// CHECK: %[[T4:.+]] = linalg.add
 // CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<?x?x4x5xf32>, tensor<?x?x4x5xf32>)
 // CHECK-SAME: outs(%[[T3]] : tensor<?x?x4x5xf32>)
 // CHECK: return %[[T4]] : tensor<?x?x4x5xf32>
 
 // -----
 
-#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
 func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
                                               %arg1 : tensor<?x?xf32>,
                                               %arg2 : tensor<?x?xf32>) ->
@@ -798,7 +792,6 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
   return %1 : tensor<?x?xf32>
 }
 
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK: func @linalg_add_reshape_producer_fusion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -817,16 +810,50 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_1]], %[[C7]] : index
 // CHECK: %[[VAL_3:.+]] = arith.divsi %[[DIM_2]], %[[C8]] : index
 // CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG2]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_2]], 7, %[[VAL_3]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
-// CHECK: %[[T3:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]]
-// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
+// CHECK: %[[T3:.+]] = linalg.add
 // CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor<?x7x?x8xf32>, tensor<?x7x?x8xf32>)
 // CHECK-SAME: outs(%[[T2]] : tensor<?x7x?x8xf32>)
 // CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]]
 // CHECK-SAME: [0, 1], [2, 3]
 // CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
 // CHECK: return %[[T4]]
 
+// -----
+
+func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
+                                                    %arg1 : tensor<?x?xf32>) ->
+  tensor<?x?xf32>
+{
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
+    tensor<?x7x?x8xf32> into tensor<?x?xf32>
+  %1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
+                        outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK: func @linalg_transpose_reshape_producer_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
+// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
+// CHECK: %[[T2:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
+// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
+// CHECK-SAME: permutation = [2, 3, 0, 1]
+// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
+// CHECK-SAME: [0, 1], [2, 3]
+// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
+// CHECK: return %[[T3]]
+
+
+
 // -----
 
 func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {