20 changes: 12 additions & 8 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -253,10 +253,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}
```

-Shortened print form is available. Applies to simple maps with one
-non-yield operation inside the body.
+Shortened print form is available for simple maps where the body contains exactly
+two operations (the payload operation and a yield), the payload operation has
+the same number of operands as block arguments with the operands matching the
+block arguments in order, and the yield operand is the result of the payload operation.

-The example above will be printed as:
+The example above will be printed using the shortened form as:
```mlir
%add = linalg.map { arith.addf }
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
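// Illustration (a sketch, not part of the diff): a long form meeting the
// conditions above, which the printer would emit in the shortened form
// shown here. The `outs` operand name %init is assumed.
%add_long = linalg.map
    ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
    outs(%init : tensor<64xf32>)
    (%lhs_elem: f32, %rhs_elem: f32) {
      %0 = arith.addf %lhs_elem, %rhs_elem : f32
      linalg.yield %0 : f32
    }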
@@ -340,13 +342,15 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
}
```

-Shortened print form is available. Applies to simple (not variadic) reduces
-with one non-yield operation inside the body. Applies only if the operation
-takes `%out` as the first argument.
+Shortened print form is available for simple reduces where the body contains
+exactly two operations (the payload operation and a yield), the payload
+operation has the same number of operands as block arguments, the first block
+argument is the last operand of the payload operation (so for a single-input
+reduce, init is the payload's first operand) with the remaining operands
+matching the remaining block arguments in order, and the yield operand is the
+result of the payload operation.

-The example above will be printed as:
+The example above will be printed using the shortened form as:
```mlir
%reduce = linalg.reduce { arith.addf }
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
dimensions = [1]
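// Illustration (a sketch, not part of the diff): a long form meeting the
// conditions above. Note that the first block argument %in is the last
// operand of the payload op, so init (%out) ends up as its first operand.
%reduce_long = linalg.reduce
    ins(%input : tensor<16x32x64xf32>)
    outs(%init : tensor<16x64xf32>)
    dimensions = [1]
    (%in: f32, %out: f32) {
      %0 = arith.addf %out, %in : f32
      linalg.yield %0 : f32
    }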
52 changes: 31 additions & 21 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1570,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}

-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+  // Check if the body can be printed in short form. The following 4 conditions
+  // must be satisfied:
+
+  // 1) The body must contain exactly 2 operations: the payload op and a yield.
   if (body->getOperations().size() != 2)
-    return nullptr;
+    return false;
   Operation &payload = body->getOperations().front();
   assert(isa<YieldOp>(body->getOperations().back()));

+  // 2) The payload op must have the same number of operands as the number of
+  //    block arguments.
   if (payload.getNumOperands() == 0 ||
       payload.getNumOperands() != body->getNumArguments())
-    return nullptr;
+    return false;
+
+  // 3) If `initFirst` is true (e.g., for reduction ops), the first block
+  //    argument must be the last operand of the payload op and the remaining
+  //    operands must match the remaining block arguments in order. Otherwise,
+  //    the operands must match the block arguments in order.
   if (initFirst) {
     // check init
     if (payload.getOperands().back() != body->getArgument(0))
-      return nullptr;
+      return false;
     // check rest
     for (const auto &[operand, bbArg] :
          llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
       if (bbArg != operand)
-        return nullptr;
+        return false;
     }
   } else {
     for (const auto &[operand, bbArg] :
          llvm::zip(payload.getOperands(), body->getArguments())) {
       if (bbArg != operand)
-        return nullptr;
+        return false;
     }
   }
-  return &payload;
+
+  // 4) The `yield` operand must be the result of the payload op (a yield of a
+  //    bare block argument has no defining op and is rejected).
+  auto yieldOp = cast<YieldOp>(body->getTerminator());
+  return yieldOp.getNumOperands() == 1 &&
+         yieldOp.getOperand(0).getDefiningOp() &&
+         yieldOp.getOperand(0).getDefiningOp() == &payload;
 }

-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
   SmallVector<StringRef> elidedAttrs;
   std::string attrToElide;
   p << " { " << payloadOp->getName().getStringRef();
@@ -1622,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {

 void MapOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  Operation *payloadOp = findPayloadOp(mapper);
-  if (payloadOp) {
-    printShortForm(p, payloadOp);
+  bool useShortForm = canUseShortForm(mapper);
+  if (useShortForm) {
+    printShortForm(p, &mapper->getOperations().front());
   }

   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
   p.printOptionalAttrDict((*this)->getAttrs());

-  if (!payloadOp) {
+  if (!useShortForm) {
     // Print region if the payload op was not detected.
     p.increaseIndent();
     p.printNewline();
@@ -1829,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,

 void ReduceOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
-  if (payloadOp) {
-    printShortForm(p, payloadOp);
+  bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
+  if (useShortForm) {
+    printShortForm(p, &mapper->getOperations().front());
   }

   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
-  if (!payloadOp) {
+  if (!useShortForm) {
     // Print region if the payload op was not detected.
     p.increaseIndent();
     p.printNewline();
49 changes: 49 additions & 0 deletions mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
// CHECK-SAME: outs
// CHECK-SAME: dimensions = [1]


// -----


func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
%reduce = linalg.reduce
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
dimensions = [1]
(%in1: f32, %in2: f32) {
%0 = arith.addf %in1, %in2: f32
linalg.yield %in1: f32
}
func.return %reduce : tensor<16x64xf32>
}

// CHECK-LABEL: func @reduce_not_short_form_compatible
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x32x64xf32>
// CHECK-SAME: %[[INIT:.*]]: tensor<16x64xf32>
// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<16x32x64xf32>
// CHECK: linalg.reduce ins(%[[INPUT]] : tensor<16x32x64xf32>) outs(%[[INIT]] : tensor<16x64xf32>)
// CHECK-SAME: dimensions = [1]
// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
// CHECK-NEXT: }
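
This body misses the short form on two counts: the yield returns `%in1` instead of the `arith.addf` result, and with `initFirst` the init block argument `%in2` would have to be the payload's first operand. For contrast, a sketch of a body that would satisfy all four conditions and print as `linalg.reduce { arith.addf }` (argument names reused from the test above):

```mlir
(%in1: f32, %in2: f32) {
  %0 = arith.addf %in2, %in1 : f32  // init (%in2) first, matching initFirst
  linalg.yield %0 : f32             // yield the payload result
}
```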

// -----

func.func @reduce_memref(%input: memref<16x32x64xf32>,
@@ -592,6 +620,27 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,

// -----

func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
%mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
(%in_1: f32, %in_2: f32) {
%1 = arith.maximumf %in_1, %in_2 : f32
linalg.yield %in_1 : f32
}
func.return %mapped : tensor<1x32xf32>
}

// CHECK-LABEL: func @map_not_short_form_compatible
// CHECK-SAME: %[[LHS:.*]]: tensor<1x32xf32>, %[[RHS:.*]]: tensor<1x32xf32>, %[[INIT:.*]]: tensor<1x32xf32>
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
// CHECK-NEXT: }
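
This is the regression the new yield check (condition 4) guards against: here the payload's operands match the block arguments in order, so the previous `findPayloadOp` accepted the body and the printer emitted the short form, which parses back to a body yielding the `arith.maximumf` result rather than `%in_1`. A sketch of the old, incorrect output:

```mlir
%mapped = linalg.map { arith.maximumf }
    ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>)
    outs(%init : tensor<1x32xf32>)
```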

// -----

func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
%reduce = linalg.reduce