Commit 5280b87

fix bug with no output bb args and add test

1 parent d723913

2 files changed: +57 −1 lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 23 additions & 0 deletions
@@ -222,8 +222,31 @@ static void generateFusedElementwiseOpRegion(
   auto producer = cast<LinalgOp>(fusedOperand->get().getDefiningOp());
   auto consumer = cast<LinalgOp>(fusedOperand->getOwner());
   // Build the region of the fused op.
+
+  // Since some ops, like `linalg.map`, do not have block arguments for init
+  // operands, we first "generalize" the block by adding arguments for init
+  // operands when they aren't present. We detect this case by checking if
+  // `getOpOperandsMatchingBBargs() == getDpsInputOperands()`.
   Block &producerBlock = producer->getRegion(0).front();
+  if (producer.getOpOperandsMatchingBBargs() ==
+      producer.getDpsInputOperands()) {
+    for (auto init : producer.getDpsInits()) {
+      Type bbType = isa<ShapedType>(init.getType())
+                        ? cast<ShapedType>(init.getType()).getElementType()
+                        : init.getType();
+      producerBlock.addArgument(bbType, producer.getLoc());
+    }
+  }
   Block &consumerBlock = consumer->getRegion(0).front();
+  if (consumer.getOpOperandsMatchingBBargs() ==
+      consumer.getDpsInputOperands()) {
+    for (auto init : consumer.getDpsInits()) {
+      Type bbType = isa<ShapedType>(init.getType())
+                        ? cast<ShapedType>(init.getType()).getElementType()
+                        : init.getType();
+      consumerBlock.addArgument(bbType, consumer.getLoc());
+    }
+  }
   OpBuilder::InsertionGuard guard(rewriter);
   Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));
   IRMapping mapper;
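For context on why those block arguments are added: the region of `linalg.map` only has block arguments for the input operands, while the fused `linalg.generic` region also expects one per init operand. A minimal sketch of the two forms (the operand names `%lhs`, `%rhs`, `%initf` are illustrative, not taken from this commit):

  // `linalg.map`: the body has one block argument per input; the init
  // operand (%initf) has no corresponding block argument.
  %sum = linalg.map ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
                    outs(%initf : tensor<8xf32>)
    (%in0 : f32, %in1 : f32) {
      %add = arith.addf %in0, %in1 : f32
      linalg.yield %add : f32
    }

  // Equivalent `linalg.generic`: the body carries an extra block argument
  // (%out) for the init operand, which is the layout the fusion code expects.
  %sum2 = linalg.generic
    {indexing_maps = [affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
    ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
    outs(%initf : tensor<8xf32>) {
  ^bb0(%in0 : f32, %in1 : f32, %out : f32):
    %add = arith.addf %in0, %in1 : f32
    linalg.yield %add : f32
  } -> tensor<8xf32>

The check `getOpOperandsMatchingBBargs() == getDpsInputOperands()` identifies the first form, where only the inputs map to block arguments, so the missing init arguments can be appended before the fused region is built.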

mlir/test/Dialect/Linalg/fusion-elementwise.mlir

Lines changed: 34 additions & 1 deletion
@@ -79,4 +79,37 @@ func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
 // CHECK-NEXT:   %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]]
 // CHECK-NEXT:   %[[SQRT:.*]] = math.sqrt %[[ADD]]
 // CHECK-NEXT:   linalg.yield %[[SQRT]]
-// CHECK-NOT:  linalg.generic
+// CHECK-NOT:  linalg.map
+
+// -----
+
+func.func @map_ops_mixed_types(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
+  %init = tensor.empty() : tensor<8xi1>
+  %initf = tensor.empty() : tensor<8xf32>
+  %0 = linalg.map {math.sqrt} ins(%arg0 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %1 = linalg.map {math.exp} ins(%arg1 : tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  %2 = linalg.map ins(%0, %1 : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xi1>)
+    (%in0 : f32, %in1 : f32) {
+      %cmp = arith.cmpf olt, %in0, %in1 : f32
+      linalg.yield %cmp : i1
+    }
+  %3 = linalg.map { arith.select } ins(%2, %0, %1 : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) outs(%initf : tensor<8xf32>)
+  return %3 : tensor<8xf32>
+}
+
+// CHECK-LABEL: func @map_ops_mixed_types
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
+// CHECK:         %[[FUSED_OP:.+]] = linalg.generic
+// CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] :
+// CHECK-NEXT:    ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:      %[[EXP0:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT:      %[[SQRT0:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT:      %[[EXP1:.*]] = math.exp %[[IN1]]
+// CHECK-NEXT:      %[[SQRT1:.*]] = math.sqrt %[[IN0]]
+// CHECK-NEXT:      %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]]
+// CHECK-NEXT:      %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]]
+// CHECK-NEXT:      linalg.yield %[[RES]]
+// CHECK-NOT:     linalg.map
+
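Read together, the CHECK lines above describe a single fused `linalg.generic` roughly of the following shape (a hand-written rendering for readability; the indexing maps, iterator types, and SSA names are assumptions, not asserted by the test):

  %empty = tensor.empty() : tensor<8xf32>
  %fused = linalg.generic
    {indexing_maps = [affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>,
                      affine_map<(d0) -> (d0)>],
     iterator_types = ["parallel"]}
    ins(%arg0, %arg1 : tensor<8xf32>, tensor<8xf32>)
    outs(%empty : tensor<8xf32>) {
  ^bb0(%in0: f32, %in1: f32, %out: f32):
    // The duplicated exp/sqrt reflect the fused producer bodies being cloned
    // per use; no CSE is expected at this point.
    %exp0  = math.exp %in1 : f32
    %sqrt0 = math.sqrt %in0 : f32
    %exp1  = math.exp %in1 : f32
    %sqrt1 = math.sqrt %in0 : f32
    %cmp   = arith.cmpf olt, %sqrt1, %exp1 : f32
    %res   = arith.select %cmp, %sqrt0, %exp0 : f32
    linalg.yield %res : f32
  } -> tensor<8xf32>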
