Skip to content

Commit 1945753

Browse files
authored
[mlir][linalg] Fix incorrect linalg short form printing (#153219)
Both `linalg.map` and `linalg.reduce` are sometimes printed in short form incorrectly, resulting in a round-trip output with different semantics. This patch adds additional `yield` operand checks to ensure that all criteria for short-form printing are satisfied. Updated/added comments and renamed the `findPayloadOp` function to `canUseShortForm`, which more accurately reflects its purpose. A couple of new lit tests check for the proper use of long form when short-form conditions are not met. Fixes #117528
1 parent ec237da commit 1945753

File tree

3 files changed

+92
-29
lines changed

3 files changed

+92
-29
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [
253253
}
254254
```
255255

256-
Shortened print form is available. Applies to simple maps with one
257-
non-yield operation inside the body.
256+
Shortened print form is available for simple maps where the body contains exactly
257+
two operations (the payload operation and a yield), the payload operation has
258+
the same number of operands as block arguments with operands matching block
259+
arguments in order, and the yield operand is the result of the payload operation.
258260

259-
The example above will be printed as:
261+
The example above will be printed using the shortened form as:
260262
```mlir
261263
%add = linalg.map { arith.addf }
262264
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
@@ -340,13 +342,15 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
340342
}
341343
```
342344

343-
Shortened print form is available. Applies to simple (not variadic) reduces
344-
with one non-yield operation inside the body. Applies only if the operation
345-
takes `%out` as the first argument.
345+
Shortened print form is available for simple reduces where the body contains exactly
346+
two operations (the payload operation and a yield), the payload operation has the
347+
same number of operands as block arguments, the first block argument (init) is the
348+
last operand of the payload operation with remaining operands matching remaining
349+
block arguments in order, and the yield operand is the result of the payload operation.
346350

347-
The example above will be printed as:
351+
The example above will be printed using the shortened form as:
348352
```mlir
349-
%reduce = linalg.reduce { arith.addf }
353+
%reduce = linalg.reduce { arith.addf }
350354
ins(%input:tensor<16x32x64xf32>)
351355
outs(%init:tensor<16x64xf32>)
352356
dimensions = [1]

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
15701570
return success();
15711571
}
15721572

1573-
// Retrieve the operation from the body, if it is the only one (except
1574-
// yield) and if it gets the same amount of arguments as the body does.
1575-
// If initFirst flag is enabled, we check that init takes the first position in
1576-
// operands of payload.
1577-
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1573+
static bool canUseShortForm(Block *body, bool initFirst = false) {
1574+
// Check if the body can be printed in short form. The following 4 conditions
1575+
// must be satisfied:
1576+
1577+
// 1) The body must contain exactly 2 operations: the payload op and a yield.
15781578
if (body->getOperations().size() != 2)
1579-
return nullptr;
1579+
return false;
15801580
Operation &payload = body->getOperations().front();
1581-
assert(isa<YieldOp>(body->getOperations().back()));
15821581

1582+
// 2) The payload op must have the same number of operands as the number of
1583+
// block arguments.
15831584
if (payload.getNumOperands() == 0 ||
15841585
payload.getNumOperands() != body->getNumArguments())
1585-
return nullptr;
1586+
return false;
1587+
1588+
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
1589+
// must be the first operand of the payload op, otherwise, the operands
1590+
// must match the block arguments in order.
15861591
if (initFirst) {
15871592
// check init
15881593
if (payload.getOperands().back() != body->getArgument(0))
1589-
return nullptr;
1594+
return false;
15901595
// check rest
15911596
for (const auto &[operand, bbArg] :
15921597
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
15931598
if (bbArg != operand)
1594-
return nullptr;
1599+
return false;
15951600
}
15961601
} else {
15971602
for (const auto &[operand, bbArg] :
15981603
llvm::zip(payload.getOperands(), body->getArguments())) {
15991604
if (bbArg != operand)
1600-
return nullptr;
1605+
return false;
16011606
}
16021607
}
1603-
return &payload;
1608+
1609+
// 4) The `yield` operand must be the result of the payload op.
1610+
auto yieldOp = cast<YieldOp>(body->getTerminator());
1611+
return yieldOp.getNumOperands() == 1 &&
1612+
yieldOp.getOperand(0).getDefiningOp() &&
1613+
yieldOp.getOperand(0).getDefiningOp() == &payload;
16041614
}
16051615

1606-
void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1616+
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
16071617
SmallVector<StringRef> elidedAttrs;
16081618
std::string attrToElide;
16091619
p << " { " << payloadOp->getName().getStringRef();
@@ -1622,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
16221632

16231633
void MapOp::print(OpAsmPrinter &p) {
16241634
Block *mapper = getBody();
1625-
Operation *payloadOp = findPayloadOp(mapper);
1626-
if (payloadOp) {
1627-
printShortForm(p, payloadOp);
1635+
bool useShortForm = canUseShortForm(mapper);
1636+
if (useShortForm) {
1637+
printShortForm(p, &mapper->getOperations().front());
16281638
}
16291639

16301640
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
16311641
p.printOptionalAttrDict((*this)->getAttrs());
16321642

1633-
if (!payloadOp) {
1643+
if (!useShortForm) {
16341644
// Print region if the payload op was not detected.
16351645
p.increaseIndent();
16361646
p.printNewline();
@@ -1829,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
18291839

18301840
void ReduceOp::print(OpAsmPrinter &p) {
18311841
Block *mapper = getBody();
1832-
Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1833-
if (payloadOp) {
1834-
printShortForm(p, payloadOp);
1842+
bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1843+
if (useShortForm) {
1844+
printShortForm(p, &mapper->getOperations().front());
18351845
}
18361846

18371847
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
18381848
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
18391849
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1840-
if (!payloadOp) {
1850+
if (!useShortForm) {
18411851
// Print region if the payload op was not detected.
18421852
p.increaseIndent();
18431853
p.printNewline();

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
436436
// CHECK-SAME: outs
437437
// CHECK-SAME: dimensions = [1]
438438

439+
440+
// -----
441+
442+
443+
func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
444+
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
445+
%reduce = linalg.reduce
446+
ins(%input:tensor<16x32x64xf32>)
447+
outs(%init:tensor<16x64xf32>)
448+
dimensions = [1]
449+
(%in1: f32, %in2: f32) {
450+
%0 = arith.addf %in1, %in2: f32
451+
linalg.yield %in1: f32
452+
}
453+
func.return %reduce : tensor<16x64xf32>
454+
}
455+
456+
// CHECK-LABEL: func @reduce_not_short_form_compatible
457+
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x32x64xf32>
458+
// CHECK-SAME: %[[INIT:.*]]: tensor<16x64xf32>
459+
// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<16x32x64xf32>
460+
// CHECK: linalg.reduce ins(%[[INPUT]] : tensor<16x32x64xf32>) outs(%[[INIT]] : tensor<16x64xf32>)
461+
// CHECK-SAME: dimensions = [1]
462+
// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
463+
// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
464+
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
465+
// CHECK-NEXT: }
466+
439467
// -----
440468

441469
func.func @reduce_memref(%input: memref<16x32x64xf32>,
@@ -592,6 +620,27 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
592620

593621
// -----
594622

623+
func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
624+
%mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
625+
(%in_1: f32, %in_2: f32) {
626+
%1 = arith.maximumf %in_1, %in_2 : f32
627+
linalg.yield %in_1 : f32
628+
}
629+
func.return %mapped : tensor<1x32xf32>
630+
}
631+
632+
// CHECK-LABEL: func @map_not_short_form_compatible
633+
// CHECK-SAME: %[[LHS:.*]]: tensor<1x32xf32>, %[[RHS:.*]]: tensor<1x32xf32>, %[[INIT:.*]]: tensor<1x32xf32>
634+
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
635+
// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
636+
// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
637+
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
638+
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
639+
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
640+
// CHECK-NEXT: }
641+
642+
// -----
643+
595644
func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
596645
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
597646
%reduce = linalg.reduce

0 commit comments

Comments
 (0)