Skip to content

Commit 20b25f3

Browse files
committed
format and add test
1 parent fa60fa5 commit 20b25f3

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 4 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -109,8 +109,9 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
109109
/// * There is a chance that the implementation of the transformation does not
110110
/// agree with the result of this method. This function gives a prediction based
111111
/// on an optimized fusion.
112-
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
113-
LinalgOp producer, LinalgOp consumer, OpOperand *fusedOperand) {
112+
llvm::SmallDenseSet<int>
113+
mlir::linalg::getPreservedProducerResults(LinalgOp producer, LinalgOp consumer,
114+
OpOperand *fusedOperand) {
114115
llvm::SmallDenseSet<int> preservedProducerResults;
115116
llvm::SmallVector<OpOperand *> opOperandsToIgnore;
116117

@@ -416,14 +417,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
416417
}
417418

418419
// Generate the fused op.
419-
// auto fusedOp = cloneWithoutRegions(rewriter, consumer,
420-
// fusedResultTypes, fusedInputOperands);
421-
// fusedOp.setIndexingMapsAttr(idxMap);
422-
// fusedOp.setIteratorTypesAttr(itTp);
423420
auto fusedOp = rewriter.create<GenericOp>(
424421
consumer.getLoc(), fusedResultTypes, fusedInputOperands,
425-
fusedOutputOperands, fusedIndexMaps,
426-
consumer.getIteratorTypesArray());
422+
fusedOutputOperands, fusedIndexMaps, consumer.getIteratorTypesArray());
427423
if (!fusedOp.getShapesToLoopsMap()) {
428424
// Fused op has invalid indexing maps. Typically this means something is off
429425
// in the input, but going ahead here would result in verification errors.

mlir/test/Dialect/Linalg/fusion-elementwise.mlir

Lines changed: 21 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -59,3 +59,24 @@ func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) ->
5959
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
6060
// CHECK-SAME: outs(%[[EMPTY]] :
6161
// CHECK-NOT: linalg.generic
62+
63+
// -----
64+
65+
func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
66+
%fill = tensor.empty() : tensor<8xf32>
67+
%add = linalg.map {arith.addf} ins(%in1, %in2: tensor<8xf32>, tensor<8xf32>) outs(%fill: tensor<8xf32>)
68+
%mapped_65 = linalg.map { math.sqrt } ins(%add : tensor<8xf32>) outs(%fill : tensor<8xf32>)
69+
return %mapped_65 : tensor<8xf32>
70+
}
71+
72+
// CHECK-LABEL: func @map_ops
73+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
74+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
75+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
76+
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
77+
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.}}) outs(%[[EMPTY]] :
78+
// CHECK-NEXT: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
79+
// CHECK-NEXT: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
80+
// CHECK-NEXT: %[[SQRT:.*]] = math.sqrt %[[ADD]]
81+
// CHECK-NEXT: linalg.yield %[[SQRT]]
82+
// CHECK-NOT: linalg.generic

0 commit comments

Comments (0)