95 changes: 77 additions & 18 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -815,6 +815,77 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
return success();
}

// Create an expanded transpose op.
static Operation *
createExpandedTransposeOp(PatternRewriter &rewriter, TransposeOp transposeOp,
SmallVector<ReassociationIndices> reassociation,
Value expandedInput, Value output) {
applyPermutationToVector(reassociation, transposeOp.getPermutation());
SmallVector<int64_t> newPerm;
for (auto reassoc : reassociation) {
for (auto dim : reassoc) {
newPerm.push_back(dim);
}
}
return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput,
output, newPerm);
}
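// For example, with reassociation [[0, 1], [2, 3]] (original dim 0 expands
// into {0, 1}, dim 1 into {2, 3}) and permutation [1, 0],
// applyPermutationToVector reorders the groups to [[2, 3], [0, 1]], and
// flattening yields the expanded permutation [2, 3, 0, 1], the case
// exercised by the @linalg_transpose_reshape_producer_fusion test below.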

// Create an expanded generic op.
static Operation *createExpandedGenericOp(
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
ArrayRef<Value> &expandedOpOperands, ArrayRef<Value> outputs,
ExpansionInfo &expansionInfo, ArrayRef<AffineMap> expandedOpIndexingMaps) {
// The iterator types of the expanded op are all parallel.
SmallVector<utils::IteratorType> iteratorTypes(
expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);

for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
for (auto j : expansionInfo.getExpandedDims(i))
iteratorTypes[j] = type;
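// For example (hypothetical values): iterator types [parallel, reduction]
// with dim 0 expanded into {0, 1} and dim 1 mapped to {2} become
// [parallel, parallel, reduction].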

Operation *fused = rewriter.create<GenericOp>(
linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
expandedOpIndexingMaps, iteratorTypes);

Region &fusedRegion = fused->getRegion(0);
Region &originalRegion = linalgOp->getRegion(0);
rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

// Update the index accesses after the expansion.
updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
expansionInfo);

return fused;
}

// Create an expanded fused op that retains the name for certain ops,
// such as fill, copy, and transpose, and produces a generic op for the
// rest of the linalg ops.
static Operation *createExpandedOp(
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
SmallVector<ReassociationIndices> reassociation) {

return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
.Case<TransposeOp>([&](TransposeOp transposeOp) {
return createExpandedTransposeOp(rewriter, transposeOp, reassociation,
expandedOpOperands[0], outputs[0]);
})
.Case<FillOp, CopyOp>([&](Operation *op) {
return clone(rewriter, linalgOp, resultTypes,
llvm::to_vector(llvm::concat<Value>(
llvm::to_vector(expandedOpOperands),
llvm::to_vector(outputs))));
})
.Default([&](Operation *op) {
return createExpandedGenericOp(rewriter, linalgOp, resultTypes,
expandedOpOperands, outputs,
expansionInfo, expandedOpIndexingMaps);
});
}
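// In other words: a collapsed linalg.transpose is rebuilt as a transpose
// with a remapped permutation, linalg.fill and linalg.copy are cloned onto
// the expanded operands, and every other linalg op falls back to a
// linalg.generic over the expanded iteration space.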

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
/// that those conditions have been satisfied.
@@ -919,25 +990,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
}
}

// The iterator types of the expanded op are all parallel.
SmallVector<utils::IteratorType> iteratorTypes(
expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
for (auto j : expansionInfo.getExpandedDims(i))
iteratorTypes[j] = type;

TypeRange resultTypes = ValueRange(outputs).getTypes();
auto fusedOp =
rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
/*inputs=*/expandedOpOperands, outputs,
expandedOpIndexingMaps, iteratorTypes);
Region &fusedRegion = fusedOp->getRegion(0);
Region &originalRegion = linalgOp->getRegion(0);
rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

// Update the index accesses after the expansion.
updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

SmallVector<ReassociationIndices> reassociationBeforeExpansion =
isExpanding ? expandingReshapeOp.getReassociationIndices()
: collapsingReshapeOp.getReassociationIndices();
Operation *fusedOp = createExpandedOp(
rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
70 changes: 67 additions & 3 deletions mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -783,9 +783,6 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,

// -----

#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
%arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) ->
@@ -829,6 +826,73 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,

// -----

func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
%arg1 : tensor<?x?xf32>) ->
tensor<?x?xf32>
{
%0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
tensor<?x7x?x8xf32> into tensor<?x?xf32>
%1 = linalg.copy ins(%0 : tensor<?x?xf32>)
outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// CHECK: func @linalg_copy_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK: %[[C8:.+]] = arith.constant 8 : index
// CHECK: %[[C7:.+]] = arith.constant 7 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
// CHECK: %[[T2:.+]] = linalg.copy
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<?x7x?x8xf32>)
// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
// CHECK: return %[[T3]]

// -----

func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
%arg1 : tensor<?x?xf32>) ->
tensor<?x?xf32>
{
%0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
tensor<?x7x?x8xf32> into tensor<?x?xf32>
%1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
return %1 : tensor<?x?xf32>
}

// CHECK: func @linalg_transpose_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
// CHECK: %[[T2:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
// CHECK-SAME: permutation = [2, 3, 0, 1]
// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
// CHECK: return %[[T3]]
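// Note (illustrative): the expanded output type follows from permuting the
// expanded input dims (?, 7, ?, 8) by [2, 3, 0, 1], giving (?, 8, ?, 7);
// hence dim 0 of %[[ARG1]] is divided by 8 and dim 1 by 7 to recover the
// dynamic output_shape values.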

// -----

func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
%cst = arith.constant 0 : i32