@@ -1570,74 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
15701570 return success ();
15711571}
15721572
1573- // Check if a block contains a single payload operation that can be printed in
1574- // short form. The block must contain exactly 2 operations: the payload op and a
1575- // yield.
1576- static bool canUseShortForm (Block *body) {
1573+ static bool canUseShortForm (Block *body, bool initFirst = false ) {
1574+ // Check if the body can be printed in short form. The following 4 conditions
1575+ // must be satisfied:
1576+
1577+ // 1) The body must contain exactly 2 operations: the payload op and a yield.
15771578 if (body->getOperations ().size () != 2 )
15781579 return false ;
1579-
15801580 Operation &payload = body->getOperations ().front ();
1581- assert (isa<YieldOp>(body->getOperations ().back ()));
1582-
1583- // Check that the yield has exactly one operand that comes from the payload
1584- auto yieldOp = cast<YieldOp>(body->getOperations ().back ());
1585- if (yieldOp.getNumOperands () != 1 )
1586- return false ;
15871581
1588- Value yieldOperand = yieldOp.getOperand (0 );
1589- if (!yieldOperand.getDefiningOp () || yieldOperand.getDefiningOp () != &payload)
1582+ // 2) The payload op must have the same number of operands as the number of
1583+ // block arguments.
1584+ if (payload.getNumOperands () == 0 ||
1585+ payload.getNumOperands () != body->getNumArguments ())
15901586 return false ;
15911587
1592- return true ;
1593- }
1594-
1595- // Find a payload operation that can be printed in short form.
1596- // For MapOp (initFirst=false): operands must match block arguments in order.
1597- // For ReduceOp (initFirst=true): init operand must be first, then operands must
1598- // match block arguments.
1599- static Operation *findPayloadOp (Block *body, bool initFirst = false ) {
1600- if (!canUseShortForm (body))
1601- return nullptr ;
1602-
1603- Operation &payload = body->getOperations ().front ();
1604-
1588+ // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1589+ // must be the first operand of the payload op, otherwise, the operands
1590+ // must match the block arguments in order.
16051591 if (initFirst) {
1606- // For ReduceOp: check that operand count matches block argument count + 1
1607- // (for init)
1608- if (payload.getNumOperands () == 0 ||
1609- payload.getNumOperands () != body->getNumArguments () + 1 )
1610- return nullptr ;
1611-
1612- // Check that init operand is first
1613- if (payload.getOperands ().front () != body->getArgument (0 ))
1614- return nullptr ;
1615-
1616- // Check that remaining operands match block arguments in order
1592+ // check init
1593+ if (payload.getOperands ().back () != body->getArgument (0 ))
1594+ return false ;
1595+ // check rest
16171596 for (const auto &[operand, bbArg] :
1618- llvm::zip (payload.getOperands ().drop_front (),
1619- body->getArguments ().drop_front ())) {
1597+ llvm::zip (payload.getOperands (), body->getArguments ().drop_front ())) {
16201598 if (bbArg != operand)
1621- return nullptr ;
1599+ return false ;
16221600 }
16231601 } else {
1624- // For MapOp: check that operand count matches block argument count
1625- if (payload.getNumOperands () == 0 ||
1626- payload.getNumOperands () != body->getNumArguments ())
1627- return nullptr ;
1628-
1629- // Check that operands match block arguments in order
16301602 for (const auto &[operand, bbArg] :
16311603 llvm::zip (payload.getOperands (), body->getArguments ())) {
16321604 if (bbArg != operand)
1633- return nullptr ;
1605+ return false ;
16341606 }
16351607 }
16361608
1637- return &payload;
1609+ // 4) The `yield` operand must be the result of the payload op.
1610+ auto yieldOp = cast<YieldOp>(body->getTerminator ());
1611+ return yieldOp.getNumOperands () == 1 &&
1612+ yieldOp.getOperand (0 ).getDefiningOp () &&
1613+ yieldOp.getOperand (0 ).getDefiningOp () == &payload;
16381614}
16391615
1640- void printShortForm (OpAsmPrinter &p, Operation *payloadOp) {
1616+ static void printShortForm (OpAsmPrinter &p, Operation *payloadOp) {
16411617 SmallVector<StringRef> elidedAttrs;
16421618 std::string attrToElide;
16431619 p << " { " << payloadOp->getName ().getStringRef ();
@@ -1656,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
16561632
16571633void MapOp::print (OpAsmPrinter &p) {
16581634 Block *mapper = getBody ();
1659- Operation *payloadOp = findPayloadOp (mapper);
1660- if (payloadOp ) {
1661- printShortForm (p, payloadOp );
1635+ bool useShortForm = canUseShortForm (mapper);
1636+ if (useShortForm ) {
1637+ printShortForm (p, &mapper-> getOperations (). front () );
16621638 }
16631639
16641640 printCommonStructuredOpParts (p, getDpsInputs (), getDpsInits ());
16651641 p.printOptionalAttrDict ((*this )->getAttrs ());
16661642
1667- if (!payloadOp ) {
1643+ if (!useShortForm ) {
16681644 // Print region if the payload op was not detected.
16691645 p.increaseIndent ();
16701646 p.printNewline ();
@@ -1863,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
18631839
18641840void ReduceOp::print (OpAsmPrinter &p) {
18651841 Block *mapper = getBody ();
1866- Operation *payloadOp = findPayloadOp (mapper, /* initFirst=*/ true );
1867- if (payloadOp ) {
1868- printShortForm (p, payloadOp );
1842+ bool useShortForm = canUseShortForm (mapper, /* initFirst=*/ true );
1843+ if (useShortForm ) {
1844+ printShortForm (p, &mapper-> getOperations (). front () );
18691845 }
18701846
18711847 printCommonStructuredOpParts (p, getDpsInputs (), getDpsInits ());
18721848 printDenseI64ArrayAttr (p, getDimensionsAttrName (), getDimensions ());
18731849 p.printOptionalAttrDict ((*this )->getAttrs (), {getDimensionsAttrName ()});
1874- if (!payloadOp ) {
1850+ if (!useShortForm ) {
18751851 // Print region if the payload op was not detected.
18761852 p.increaseIndent ();
18771853 p.printNewline ();
0 commit comments