73 changes: 55 additions & 18 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -814,6 +814,55 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
}
return success();
}
/// Create an expanded fused op that retains the op name for certain ops,
/// such as fill, copy, and transpose, and produces a generic op for the
/// rest of the linalg ops.
static Operation *createFusedOpForReshapeByExpansion(
PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
SmallVector<ReassociationIndices> reassociation) {

return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation())
.Case<TransposeOp>([&](TransposeOp op) {
applyPermutationToVector(reassociation, op.getPermutation());
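// For example, with reassociation [[0, 1], [2, 3]] and permutation
// [1, 0], the permuted reassociation is [[2, 3], [0, 1]], which
// flattens to the expanded permutation [2, 3, 0, 1].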
SmallVector<int64_t> newPerm;
for (const auto &reassoc : reassociation)
  llvm::append_range(newPerm, reassoc);
return rewriter.create<TransposeOp>(
linalgOp.getLoc(), expandedOpOperands[0], outputs[0], newPerm);
})
.Case<FillOp, CopyOp>([&](Operation *op) {
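// Fill and copy are rank-polymorphic, so cloning the op with the
// expanded operands and result types is sufficient.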
return clone(rewriter, linalgOp, resultTypes,
llvm::to_vector(llvm::concat<Value>(
llvm::to_vector(expandedOpOperands),
llvm::to_vector(outputs))));
})
.Default([&](Operation *op) {
// Initialize the iterator types of the expanded op to parallel, then
// scatter each original iterator type over the dims it expands into.
SmallVector<utils::IteratorType> iteratorTypes(
expansionInfo.getExpandedOpNumDims(),
utils::IteratorType::parallel);
for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
for (auto j : expansionInfo.getExpandedDims(i))
iteratorTypes[j] = type;
Operation *fused = rewriter.create<GenericOp>(
linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs,
expandedOpIndexingMaps, iteratorTypes);
Region &fusedRegion = fused->getRegion(0);
Region &originalRegion = linalgOp->getRegion(0);
rewriter.cloneRegionBefore(originalRegion, fusedRegion,
fusedRegion.begin());

// Update the index accesses after the expansion.
updateExpandedGenericOpRegion(rewriter, linalgOp.getLoc(), fusedRegion,
expansionInfo);
return fused;
});
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
/// and a generic op as explained in `isFusableWithReshapeByExpansion`. Assumes
@@ -919,25 +968,13 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
}
}

// The iterator types of the expanded op are all parallel.
SmallVector<utils::IteratorType> iteratorTypes(
expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
for (auto j : expansionInfo.getExpandedDims(i))
iteratorTypes[j] = type;

TypeRange resultTypes = ValueRange(outputs).getTypes();
auto fusedOp =
rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
/*inputs=*/expandedOpOperands, outputs,
expandedOpIndexingMaps, iteratorTypes);
Region &fusedRegion = fusedOp->getRegion(0);
Region &originalRegion = linalgOp->getRegion(0);
rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

// Update the index accesses after the expansion.
updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

SmallVector<ReassociationIndices> reassociationBeforeExpansion =
isExpanding ? expandingReshapeOp.getReassociationIndices()
: collapsingReshapeOp.getReassociationIndices();
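// Pass the reassociation as seen before expansion so that named ops such
// as linalg.transpose can remap their permutation onto the expanded
// dimensions.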
Operation *fusedOp = createFusedOpForReshapeByExpansion(
rewriter, linalgOp, resultTypes, expandedOpOperands, outputs,
expandedOpIndexingMaps, expansionInfo, reassociationBeforeExpansion);
// Reshape the result values to their original shape if this is a collapsing
// reshape folded into its consumer.
SmallVector<Value> resultVals;
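For context, a minimal sketch of how these expansion-based reshape-fusion patterns are typically driven from a pass. `populateFoldReshapeOpsByExpansionPatterns` is the existing Linalg entry point; the unconditional control function and the pass boilerplate around it are illustrative assumptions, not part of this change:

  // Collect the reshape-by-expansion fusion patterns and apply them
  // greedily to the operation this pass runs on.
  RewritePatternSet patterns(&getContext());
  // Illustrative control function: fuse unconditionally. A real pass would
  // inspect the fused operand before agreeing to fuse.
  linalg::ControlFusionFn controlFn = [](OpOperand *fusedOperand) {
    return true;
  };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();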
70 changes: 67 additions & 3 deletions mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -783,9 +783,6 @@ func.func @linalg_add_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,

// -----

#map0 = affine_map<(d0, d1, d2) -> (d2, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
%arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) ->
@@ -829,6 +826,73 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,

// -----
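// Fusing the collapse_shape producer must keep the named linalg.copy op on
// the expanded operands instead of rewriting it to a linalg.generic.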

func.func @linalg_copy_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
%arg1 : tensor<?x?xf32>) ->
tensor<?x?xf32>
{
%0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
tensor<?x7x?x8xf32> into tensor<?x?xf32>
%1 = linalg.copy ins(%0 : tensor<?x?xf32>)
outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// CHECK: func @linalg_copy_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM]], %[[C7]] : index
// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM_0]], %[[C8]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_0]], 7, %[[VAL_1]], 8] : tensor<?x?xf32> into tensor<?x7x?x8xf32>
// CHECK: %[[T2:.+]] = linalg.copy
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<?x7x?x8xf32>)
// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x7x?x8xf32> into tensor<?x?xf32>
// CHECK: return %[[T3]]

// -----
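// Fusing the collapse_shape producer must keep the named linalg.transpose op
// and remap its permutation to the expanded dims ([1, 0] -> [2, 3, 0, 1]).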

func.func @linalg_transpose_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
%arg1 : tensor<?x?xf32>) ->
tensor<?x?xf32>
{
%0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] :
tensor<?x7x?x8xf32> into tensor<?x?xf32>
%1 = linalg.transpose ins(%0 : tensor<?x?xf32>)
outs(%arg1 : tensor<?x?xf32>) permutation = [1, 0]
return %1 : tensor<?x?xf32>
}

// CHECK: func @linalg_transpose_reshape_producer_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x7x?x8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C7]] : index
// CHECK-DAG: %[[VAL_1:.+]] = arith.divsi %[[DIM]], %[[C8]] : index
// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 8, %[[VAL_0]], 7] : tensor<?x?xf32> into tensor<?x8x?x7xf32>
// CHECK: %[[T2:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[ARG0]] : tensor<?x7x?x8xf32>)
// CHECK-SAME: outs(%[[T1]] : tensor<?x8x?x7xf32>)
// CHECK-SAME: permutation = [2, 3, 0, 1]
// CHECK: %[[T3:.+]] = tensor.collapse_shape %[[T2]]
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: tensor<?x8x?x7xf32> into tensor<?x?xf32>
// CHECK: return %[[T3]]

// -----

func.func @fuse_by_expanding_pad(%arg0 : tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
%collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] : tensor<2x3x4x5x6x7x8x9xi32> into tensor<2x12x5x336x9xi32>
%cst = arith.constant 0 : i32