@@ -1570,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
15701570 return success ();
15711571}
15721572
1573- // Retrieve the operation from the body, if it is the only one (except
1574- // yield) and if it gets the same amount of arguments as the body does.
1575- // If initFirst flag is enabled, we check that init takes the first position in
1576- // operands of payload.
1577- static Operation * findPayloadOp (Block * body, bool initFirst = false ) {
1573+ static bool canUseShortForm (Block * body, bool initFirst = false ) {
1574+ // Check if the body can be printed in short form. The following 4 conditions
1575+ // must be satisfied:
1576+
1577+ // 1) The body must contain exactly 2 operations: the payload op and a yield.
15781578 if (body->getOperations ().size () != 2 )
1579- return nullptr ;
1579+ return false ;
15801580 Operation &payload = body->getOperations ().front ();
1581- assert (isa<YieldOp>(body->getOperations ().back ()));
15821581
1582+ // 2) The payload op must have the same number of operands as the number of
1583+ // block arguments.
15831584 if (payload.getNumOperands () == 0 ||
15841585 payload.getNumOperands () != body->getNumArguments ())
1585- return nullptr ;
1586+ return false ;
1587+
1588+ // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1589+ // must be the first operand of the payload op, otherwise, the operands
1590+ // must match the block arguments in order.
15861591 if (initFirst) {
15871592 // check init
15881593 if (payload.getOperands ().back () != body->getArgument (0 ))
1589- return nullptr ;
1594+ return false ;
15901595 // check rest
15911596 for (const auto &[operand, bbArg] :
15921597 llvm::zip (payload.getOperands (), body->getArguments ().drop_front ())) {
15931598 if (bbArg != operand)
1594- return nullptr ;
1599+ return false ;
15951600 }
15961601 } else {
15971602 for (const auto &[operand, bbArg] :
15981603 llvm::zip (payload.getOperands (), body->getArguments ())) {
15991604 if (bbArg != operand)
1600- return nullptr ;
1605+ return false ;
16011606 }
16021607 }
1603- return &payload;
1608+
1609+ // 4) The `yield` operand must be the result of the payload op.
1610+ auto yieldOp = cast<YieldOp>(body->getTerminator ());
1611+ return yieldOp.getNumOperands () == 1 &&
1612+ yieldOp.getOperand (0 ).getDefiningOp () &&
1613+ yieldOp.getOperand (0 ).getDefiningOp () == &payload;
16041614}
16051615
1606- void printShortForm (OpAsmPrinter &p, Operation *payloadOp) {
1616+ static void printShortForm (OpAsmPrinter &p, Operation *payloadOp) {
16071617 SmallVector<StringRef> elidedAttrs;
16081618 std::string attrToElide;
16091619 p << " { " << payloadOp->getName ().getStringRef ();
@@ -1622,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
16221632
16231633void MapOp::print (OpAsmPrinter &p) {
16241634 Block *mapper = getBody ();
1625- Operation *payloadOp = findPayloadOp (mapper);
1626- if (payloadOp ) {
1627- printShortForm (p, payloadOp );
1635+ bool useShortForm = canUseShortForm (mapper);
1636+ if (useShortForm ) {
1637+ printShortForm (p, &mapper-> getOperations (). front () );
16281638 }
16291639
16301640 printCommonStructuredOpParts (p, getDpsInputs (), getDpsInits ());
16311641 p.printOptionalAttrDict ((*this )->getAttrs ());
16321642
1633- if (!payloadOp ) {
1643+ if (!useShortForm ) {
16341644 // Print region if the payload op was not detected.
16351645 p.increaseIndent ();
16361646 p.printNewline ();
@@ -1829,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
18291839
18301840void ReduceOp::print (OpAsmPrinter &p) {
18311841 Block *mapper = getBody ();
1832- Operation *payloadOp = findPayloadOp (mapper, /* initFirst=*/ true );
1833- if (payloadOp ) {
1834- printShortForm (p, payloadOp );
1842+ bool useShortForm = canUseShortForm (mapper, /* initFirst=*/ true );
1843+ if (useShortForm ) {
1844+ printShortForm (p, &mapper-> getOperations (). front () );
18351845 }
18361846
18371847 printCommonStructuredOpParts (p, getDpsInputs (), getDpsInits ());
18381848 printDenseI64ArrayAttr (p, getDimensionsAttrName (), getDimensions ());
18391849 p.printOptionalAttrDict ((*this )->getAttrs (), {getDimensionsAttrName ()});
1840- if (!payloadOp ) {
1850+ if (!useShortForm ) {
18411851 // Print region if the payload op was not detected.
18421852 p.increaseIndent ();
18431853 p.printNewline ();
0 commit comments