Skip to content

Commit 16080a0

Browse files
committed
cleanup
1 parent c24793e commit 16080a0

File tree

1 file changed

+33
-57
lines changed

1 file changed

+33
-57
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 33 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,74 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
15701570
return success();
15711571
}
15721572

1573-
// Check if a block contains a single payload operation that can be printed in
1574-
// short form. The block must contain exactly 2 operations: the payload op and a
1575-
// yield.
1576-
static bool canUseShortForm(Block *body) {
1573+
static bool canUseShortForm(Block *body, bool initFirst = false) {
1574+
// Check if the body can be printed in short form. The following 4 conditions
1575+
// must be satisfied:
1576+
1577+
// 1) The body must contain exactly 2 operations: the payload op and a yield.
15771578
if (body->getOperations().size() != 2)
15781579
return false;
1579-
15801580
Operation &payload = body->getOperations().front();
1581-
assert(isa<YieldOp>(body->getOperations().back()));
1582-
1583-
// Check that the yield has exactly one operand that comes from the payload
1584-
auto yieldOp = cast<YieldOp>(body->getOperations().back());
1585-
if (yieldOp.getNumOperands() != 1)
1586-
return false;
15871581

1588-
Value yieldOperand = yieldOp.getOperand(0);
1589-
if (!yieldOperand.getDefiningOp() || yieldOperand.getDefiningOp() != &payload)
1582+
// 2) The payload op must have the same number of operands as the number of
1583+
// block arguments.
1584+
if (payload.getNumOperands() == 0 ||
1585+
payload.getNumOperands() != body->getNumArguments())
15901586
return false;
15911587

1592-
return true;
1593-
}
1594-
1595-
// Find a payload operation that can be printed in short form.
1596-
// For MapOp (initFirst=false): operands must match block arguments in order.
1597-
// For ReduceOp (initFirst=true): init operand must be first, then operands must
1598-
// match block arguments.
1599-
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1600-
if (!canUseShortForm(body))
1601-
return nullptr;
1602-
1603-
Operation &payload = body->getOperations().front();
1604-
1588+
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
1589+
// must be the first operand of the payload op, otherwise, the operands
1590+
// must match the block arguments in order.
16051591
if (initFirst) {
1606-
// For ReduceOp: check that operand count matches block argument count + 1
1607-
// (for init)
1608-
if (payload.getNumOperands() == 0 ||
1609-
payload.getNumOperands() != body->getNumArguments() + 1)
1610-
return nullptr;
1611-
1612-
// Check that init operand is first
1613-
if (payload.getOperands().front() != body->getArgument(0))
1614-
return nullptr;
1615-
1616-
// Check that remaining operands match block arguments in order
1592+
// check init
1593+
if (payload.getOperands().back() != body->getArgument(0))
1594+
return false;
1595+
// check rest
16171596
for (const auto &[operand, bbArg] :
1618-
llvm::zip(payload.getOperands().drop_front(),
1619-
body->getArguments().drop_front())) {
1597+
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
16201598
if (bbArg != operand)
1621-
return nullptr;
1599+
return false;
16221600
}
16231601
} else {
1624-
// For MapOp: check that operand count matches block argument count
1625-
if (payload.getNumOperands() == 0 ||
1626-
payload.getNumOperands() != body->getNumArguments())
1627-
return nullptr;
1628-
1629-
// Check that operands match block arguments in order
16301602
for (const auto &[operand, bbArg] :
16311603
llvm::zip(payload.getOperands(), body->getArguments())) {
16321604
if (bbArg != operand)
1633-
return nullptr;
1605+
return false;
16341606
}
16351607
}
16361608

1637-
return &payload;
1609+
// 4) The `yield` operand must be the result of the payload op.
1610+
auto yieldOp = cast<YieldOp>(body->getTerminator());
1611+
return yieldOp.getNumOperands() == 1 &&
1612+
yieldOp.getOperand(0).getDefiningOp() &&
1613+
yieldOp.getOperand(0).getDefiningOp() == &payload;
16381614
}
16391615

1640-
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1616+
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
16411617
SmallVector<StringRef> elidedAttrs;
16421618
std::string attrToElide;
16431619
p << " { " << payloadOp->getName().getStringRef();
@@ -1656,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
16561632

16571633
void MapOp::print(OpAsmPrinter &p) {
16581634
Block *mapper = getBody();
1659-
Operation *payloadOp = findPayloadOp(mapper);
1660-
if (payloadOp) {
1661-
printShortForm(p, payloadOp);
1635+
bool useShortForm = canUseShortForm(mapper);
1636+
if (useShortForm) {
1637+
printShortForm(p, &mapper->getOperations().front());
16621638
}
16631639

16641640
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
16651641
p.printOptionalAttrDict((*this)->getAttrs());
16661642

1667-
if (!payloadOp) {
1643+
if (!useShortForm) {
16681644
// Print region if the payload op was not detected.
16691645
p.increaseIndent();
16701646
p.printNewline();
@@ -1863,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
18631839

18641840
void ReduceOp::print(OpAsmPrinter &p) {
18651841
Block *mapper = getBody();
1866-
Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1867-
if (payloadOp) {
1868-
printShortForm(p, payloadOp);
1842+
bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1843+
if (useShortForm) {
1844+
printShortForm(p, &mapper->getOperations().front());
18691845
}
18701846

18711847
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
18721848
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
18731849
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1874-
if (!payloadOp) {
1850+
if (!useShortForm) {
18751851
// Print region if the payload op was not detected.
18761852
p.increaseIndent();
18771853
p.printNewline();

0 commit comments

Comments
 (0)