Skip to content

Commit 58582bf

Browse files
committed
use getElementTypeOrSelf for cleanup
1 parent 8d2e8e0 commit 58582bf

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -230,22 +230,16 @@ static void generateFusedElementwiseOpRegion(
230230
Block &producerBlock = producer->getRegion(0).front();
231231
if (producer.getOpOperandsMatchingBBargs() ==
232232
producer.getDpsInputOperands()) {
233-
for (auto init : producer.getDpsInits()) {
234-
Type bbType = isa<ShapedType>(init.getType())
235-
? cast<ShapedType>(init.getType()).getElementType()
236-
: init.getType();
237-
producerBlock.addArgument(bbType, producer.getLoc());
238-
}
233+
for (auto init : producer.getDpsInits())
234+
producerBlock.addArgument(getElementTypeOrSelf(init.getType()),
235+
producer.getLoc());
239236
}
240237
Block &consumerBlock = consumer->getRegion(0).front();
241238
if (consumer.getOpOperandsMatchingBBargs() ==
242239
consumer.getDpsInputOperands()) {
243-
for (auto init : consumer.getDpsInits()) {
244-
Type bbType = isa<ShapedType>(init.getType())
245-
? cast<ShapedType>(init.getType()).getElementType()
246-
: init.getType();
247-
consumerBlock.addArgument(bbType, consumer.getLoc());
248-
}
240+
for (auto init : consumer.getDpsInits())
241+
consumerBlock.addArgument(getElementTypeOrSelf(init.getType()),
242+
consumer.getLoc());
249243
}
250244
OpBuilder::InsertionGuard guard(rewriter);
251245
Block *fusedBlock = rewriter.createBlock(&fusedOp->getRegion(0));

0 commit comments

Comments
 (0)