@@ -1570,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1570
1570
return success ();
1571
1571
}
1572
1572
1573
- // Retrieve the operation from the body, if it is the only one (except
1574
- // yield) and if it gets the same amount of arguments as the body does.
1575
- // If initFirst flag is enabled, we check that init takes the first position in
1576
- // operands of payload.
1577
- static Operation * findPayloadOp (Block * body, bool initFirst = false ) {
1573
+ static bool canUseShortForm (Block * body, bool initFirst = false ) {
1574
+ // Check if the body can be printed in short form. The following 4 conditions
1575
+ // must be satisfied:
1576
+
1577
+ // 1) The body must contain exactly 2 operations: the payload op and a yield.
1578
1578
if (body->getOperations ().size () != 2 )
1579
- return nullptr ;
1579
+ return false ;
1580
1580
Operation &payload = body->getOperations ().front ();
1581
- assert (isa<YieldOp>(body->getOperations ().back ()));
1582
1581
1582
+ // 2) The payload op must have the same number of operands as the number of
1583
+ // block arguments.
1583
1584
if (payload.getNumOperands () == 0 ||
1584
1585
payload.getNumOperands () != body->getNumArguments ())
1585
- return nullptr ;
1586
+ return false ;
1587
+
1588
+ // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1589
+ // must be the first operand of the payload op, otherwise, the operands
1590
+ // must match the block arguments in order.
1586
1591
if (initFirst) {
1587
1592
// check init
1588
1593
if (payload.getOperands ().back () != body->getArgument (0 ))
1589
- return nullptr ;
1594
+ return false ;
1590
1595
// check rest
1591
1596
for (const auto &[operand, bbArg] :
1592
1597
llvm::zip (payload.getOperands (), body->getArguments ().drop_front ())) {
1593
1598
if (bbArg != operand)
1594
- return nullptr ;
1599
+ return false ;
1595
1600
}
1596
1601
} else {
1597
1602
for (const auto &[operand, bbArg] :
1598
1603
llvm::zip (payload.getOperands (), body->getArguments ())) {
1599
1604
if (bbArg != operand)
1600
- return nullptr ;
1605
+ return false ;
1601
1606
}
1602
1607
}
1603
- return &payload;
1608
+
1609
+ // 4) The `yield` operand must be the result of the payload op.
1610
+ auto yieldOp = cast<YieldOp>(body->getTerminator ());
1611
+ return yieldOp.getNumOperands () == 1 &&
1612
+ yieldOp.getOperand (0 ).getDefiningOp () &&
1613
+ yieldOp.getOperand (0 ).getDefiningOp () == &payload;
1604
1614
}
1605
1615
1606
- void printShortForm (OpAsmPrinter &p, Operation *payloadOp) {
1616
+ static void printShortForm (OpAsmPrinter &p, Operation *payloadOp) {
1607
1617
SmallVector<StringRef> elidedAttrs;
1608
1618
std::string attrToElide;
1609
1619
p << " { " << payloadOp->getName ().getStringRef ();
@@ -1622,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1622
1632
1623
1633
void MapOp::print (OpAsmPrinter &p) {
1624
1634
Block *mapper = getBody ();
1625
- Operation *payloadOp = findPayloadOp (mapper);
1626
- if (payloadOp ) {
1627
- printShortForm (p, payloadOp );
1635
+ bool useShortForm = canUseShortForm (mapper);
1636
+ if (useShortForm ) {
1637
+ printShortForm (p, &mapper-> getOperations (). front () );
1628
1638
}
1629
1639
1630
1640
printCommonStructuredOpParts (p, getDpsInputs (), getDpsInits ());
1631
1641
p.printOptionalAttrDict ((*this )->getAttrs ());
1632
1642
1633
- if (!payloadOp ) {
1643
+ if (!useShortForm ) {
1634
1644
// Print region if the payload op was not detected.
1635
1645
p.increaseIndent ();
1636
1646
p.printNewline ();
@@ -1829,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1829
1839
1830
1840
void ReduceOp::print (OpAsmPrinter &p) {
1831
1841
Block *mapper = getBody ();
1832
- Operation *payloadOp = findPayloadOp (mapper, /* initFirst=*/ true );
1833
- if (payloadOp ) {
1834
- printShortForm (p, payloadOp );
1842
+ bool useShortForm = canUseShortForm (mapper, /* initFirst=*/ true );
1843
+ if (useShortForm ) {
1844
+ printShortForm (p, &mapper-> getOperations (). front () );
1835
1845
}
1836
1846
1837
1847
printCommonStructuredOpParts (p, getDpsInputs (), getDpsInits ());
1838
1848
printDenseI64ArrayAttr (p, getDimensionsAttrName (), getDimensions ());
1839
1849
p.printOptionalAttrDict ((*this )->getAttrs (), {getDimensionsAttrName ()});
1840
- if (!payloadOp ) {
1850
+ if (!useShortForm ) {
1841
1851
// Print region if the payload op was not detected.
1842
1852
p.increaseIndent ();
1843
1853
p.printNewline ();
0 commit comments