Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,54 @@ struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
AffineMap map = op.getMatchingIndexingMap(operand);
auto transposeOp = operand->get().getDefiningOp<TransposeOp>();

if (!map.isIdentity() || !transposeOp) {
if (!transposeOp) {
// push in original operand and its map.
newIns.push_back(operand->get());
newMaps.push_back(map);
continue;
}
newIns.push_back(transposeOp.getInput());
// push in transposeOp's inverse permutation map.
newMaps.push_back(transposeOp.getMatchingIndexingMap(
transposeOp.getDpsInputOperand(0)));
// push in composed affine map.
newMaps.push_back(
transposeOp.getMatchingIndexingMap(transposeOp.getDpsInputOperand(0))
.compose(map));
changed = true;
}
if (!changed)
return failure();
newMaps.push_back(op.getIndexingMapsArray().back());

rewriter.replaceOpWithNewOp<ElementwiseOp>(
op, newIns, op.getDpsInits()[0], op.getKindAttr(),
rewriter.getAffineMapArrayAttr(newMaps));
return success();
}
};

struct FoldBroadcastPattern : public OpRewritePattern<ElementwiseOp> {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the logic in the transpose and broadcast pattern is identical -- could you make the op class a template parameter, or alternatively handle both in the same pattern?

using OpRewritePattern<ElementwiseOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ElementwiseOp op,
PatternRewriter &rewriter) const override {
bool changed = false;
SmallVector<Value> newIns;
SmallVector<AffineMap> newMaps;
for (OpOperand *operand : op.getDpsInputOperands()) {
AffineMap map = op.getMatchingIndexingMap(operand);
auto broadcastOp = operand->get().getDefiningOp<BroadcastOp>();

if (!broadcastOp) {
// push in original operand and its map.
newIns.push_back(operand->get());
newMaps.push_back(map);
continue;
}

newIns.push_back(broadcastOp.getInput());
// push in composed affine map.
newMaps.push_back(
broadcastOp.getMatchingIndexingMap(broadcastOp.getDpsInputOperand(0))
.compose(map));
changed = true;
}
if (!changed)
Expand Down Expand Up @@ -84,4 +122,5 @@ struct LinalgFoldIntoElementwisePass
/// Populates `patterns` with the rewrites that fold producer
/// `linalg.transpose` / `linalg.broadcast` ops into a consuming
/// `linalg.elementwise` op by composing their indexing maps.
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
    RewritePatternSet &patterns) {
  // Register both folding patterns in a single variadic call.
  patterns.add<FoldTransposePattern, FoldBroadcastPattern>(
      patterns.getContext());
}
162 changes: 155 additions & 7 deletions mlir/test/Dialect/Linalg/elementwise/fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
//
func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
func.func @unary_transpose(%A: tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%empty = tensor.empty() : tensor<8x16x32xf32>
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %result : tensor<8x16x32xf32>
}

Expand All @@ -28,16 +28,164 @@ func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) ->
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
//
func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
func.func @binary_transposed(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>

%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
//
// CHECK: func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
// CHECK-SAME: indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
//
// Unary case: exp applied to a broadcast of %A (8x32 -> 8x16x32, new dim at
// index 1). The fold pass should erase the broadcast and instead attach the
// projected map (d0, d1, d2) -> (d0, d2) to the elementwise input.
func.func @unary_broadcasted(%A: tensor<8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%empty = tensor.empty() : tensor<8x16x32xf32>
// Materialize %A along a new middle dimension of size 16.
%broadcasted_A = linalg.broadcast ins(%A : tensor<8x32xf32>) outs(%empty : tensor<8x16x32xf32>) dimensions = [1]
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %result : tensor<8x16x32xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
//
// CHECK: func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
//
// Binary case with dynamic shapes: only the second operand is broadcast
// (?xf32 -> ?x?xf32, new dim at index 1), so after folding its map becomes
// (d0, d1) -> (d0) while %A keeps the identity map.
func.func @binary_broadcasted(%A: tensor<?x?xf32>, %B: tensor<?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>

// NOTE(review): the (%dim1, %dim0) ordering looks copied from the transpose
// test; with all-dynamic shapes this still verifies, but (%dim0, %dim1)
// would mirror %C more clearly — confirm intent.
%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
%broadcasted_B = linalg.broadcast ins(%B : tensor<?xf32>) outs(%empty : tensor<?x?xf32>) dimensions = [1]
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
ins(%A, %broadcasted_B : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
//
// CHECK: func.func @fold_broadcast_after_transpose_fold(%[[A:.+]]: tensor<16xf32>, %[[B:.+]]: tensor<16x32xf32>) -> tensor<16x32xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]] : tensor<16xf32>) outs(%[[B]] : tensor<16x32xf32>) -> tensor<16x32xf32>
// CHECK-NEXT: return %[[RES]] : tensor<16x32xf32>
//
// Chained producers: broadcast (16 -> 32x16, dim 0) then transpose
// ([1, 0] -> 16x32) feeding exp. Both folds should apply, composing down to
// the single map (d0, d1) -> (d0) on the original %A.
func.func @fold_broadcast_after_transpose_fold(%A: tensor<16xf32>, %B: tensor<16x32xf32>) -> tensor<16x32xf32> {
%empty_b = tensor.empty() : tensor<32x16xf32>
// First producer: add a leading dimension of size 32.
%broadcasted_A = linalg.broadcast ins(%A : tensor<16xf32>) outs(%empty_b : tensor<32x16xf32>) dimensions = [0]

%empty_t = tensor.empty() : tensor<16x32xf32>
// Second producer: swap the two dimensions.
%transposed_A = linalg.transpose ins(%broadcasted_A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0]

%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
ins(%transposed_A : tensor<16x32xf32>) outs(%B : tensor<16x32xf32>) -> tensor<16x32xf32>
return %result : tensor<16x32xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
//
// CHECK: func.func @fold_transpose_after_broadcast_fold(%[[A:.+]]: tensor<32x16xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]] : tensor<32x16xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
//
// Chained producers in the opposite order: transpose (32x16 -> 16x32) then
// broadcast (-> 8x16x32, new leading dim). Folding both should leave %A with
// the composed map (d0, d1, d2) -> (d2, d1).
func.func @fold_transpose_after_broadcast_fold(%A: tensor<32x16xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
%empty_t = tensor.empty() : tensor<16x32xf32>
// First producer: swap the two dimensions of %A.
%transposed_A = linalg.transpose ins(%A : tensor<32x16xf32>) outs(%empty_t : tensor<16x32xf32>) permutation = [1, 0]

%empty_b = tensor.empty() : tensor<8x16x32xf32>
// Second producer: prepend a dimension of size 8.
%broadcasted_A = linalg.broadcast ins(%transposed_A : tensor<16x32xf32>) outs(%empty_b : tensor<8x16x32xf32>) dimensions = [0]

%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
ins(%broadcasted_A : tensor<8x16x32xf32>) outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
return %result : tensor<8x16x32xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
//
// CHECK: func.func @fold_broadcast_after_transpose_fold_binary(%[[A:.+]]: tensor<?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
//
// Binary, dynamic-shape variant of broadcast-then-transpose chaining: only
// the first operand goes through broadcast (dim 0) + transpose ([1, 0]);
// both should fold, giving %A the composed map (d0, d1) -> (d0).
func.func @fold_broadcast_after_transpose_fold_binary(%A: tensor<?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %B, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %B, %c1 : tensor<?x?xf32>

// First producer: add a leading dynamic dimension to %A.
%empty_b = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
%broadcasted_A = linalg.broadcast ins(%A : tensor<?xf32>) outs(%empty_b : tensor<?x?xf32>) dimensions = [0]

// Second producer: transpose back so shapes line up with %B / %C.
%empty_t = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
%transposed_A = linalg.transpose ins(%broadcasted_A : tensor<?x?xf32>) outs(%empty_t : tensor<?x?xf32>) permutation = [1, 0]

%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
ins(%transposed_A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
return %result : tensor<?x?xf32>
}

// -----

// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[COMPOSED_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
//
// CHECK: func.func @fold_transpose_after_broadcast_fold_binary(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?x?xf32>, %[[C:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
// CHECK-SAME: indexing_maps = [#[[COMPOSED_MAP]], #[[IDENTITY]], #[[IDENTITY]]]
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?x?xf32>) outs(%[[C]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NEXT: return %[[RES]] : tensor<?x?x?xf32>
//
// Binary, dynamic-shape variant of transpose-then-broadcast chaining: the
// first operand is transposed ([1, 0]) and then broadcast along a new leading
// dimension; folding both should leave %A with (d0, d1, d2) -> (d2, d1).
func.func @fold_transpose_after_broadcast_fold_binary(%A: tensor<?x?xf32>, %B: tensor<?x?x?xf32>, %C: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%dim0 = tensor.dim %B, %c0 : tensor<?x?x?xf32>
%dim1 = tensor.dim %B, %c1 : tensor<?x?x?xf32>
%dim2 = tensor.dim %B, %c2 : tensor<?x?x?xf32>

// First producer: swap %A's two dynamic dimensions.
%empty_t = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
%transposed_A = linalg.transpose ins(%A : tensor<?x?xf32>) outs(%empty_t : tensor<?x?xf32>) permutation = [1, 0]

// Second producer: prepend the leading dimension taken from %B.
%empty_b = tensor.empty(%dim0, %dim1, %dim2) : tensor<?x?x?xf32>
%broadcasted_A = linalg.broadcast ins(%transposed_A : tensor<?x?xf32>) outs(%empty_b : tensor<?x?x?xf32>) dimensions = [0]

%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
ins(%broadcasted_A, %B : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%C : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %result : tensor<?x?x?xf32>
}
Loading