Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1570,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}

// Retrieve the operation from the body, if it is the only one (except
// yield) and if it gets the same amount of arguments as the body does.
// If initFirst flag is enabled, we check that init takes the first position in
// operands of payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
static bool canUseShortForm(Block *body, bool initFirst = false) {
// Check if the body can be printed in short form. The following 4 conditions
// must be satisfied:

// 1) The body must contain exactly 2 operations: the payload op and a yield.
if (body->getOperations().size() != 2)
return nullptr;
return false;
Operation &payload = body->getOperations().front();
assert(isa<YieldOp>(body->getOperations().back()));

// 2) The payload op must have the same number of operands as the number of
// block arguments.
if (payload.getNumOperands() == 0 ||
payload.getNumOperands() != body->getNumArguments())
return nullptr;
return false;

// 3) If `initFirst` is true (e.g., for reduction ops), the init block
// must be the first operand of the payload op, otherwise, the operands
// must match the block arguments in order.
if (initFirst) {
// check init
if (payload.getOperands().back() != body->getArgument(0))
return nullptr;
return false;
// check rest
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
if (bbArg != operand)
return nullptr;
return false;
}
} else {
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments())) {
if (bbArg != operand)
return nullptr;
return false;
}
}
return &payload;

// 4) The `yield` operand must be the result of the payload op.
auto yieldOp = cast<YieldOp>(body->getTerminator());
return yieldOp.getNumOperands() == 1 &&
yieldOp.getOperand(0).getDefiningOp() &&
yieldOp.getOperand(0).getDefiningOp() == &payload;
}

void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
SmallVector<StringRef> elidedAttrs;
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
Expand All @@ -1622,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {

void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
Operation *payloadOp = findPayloadOp(mapper);
if (payloadOp) {
printShortForm(p, payloadOp);
bool useShortForm = canUseShortForm(mapper);
if (useShortForm) {
printShortForm(p, &mapper->getOperations().front());
}

printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
p.printOptionalAttrDict((*this)->getAttrs());

if (!payloadOp) {
if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
Expand Down Expand Up @@ -1829,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,

void ReduceOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
if (payloadOp) {
printShortForm(p, payloadOp);
bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
if (useShortForm) {
printShortForm(p, &mapper->getOperations().front());
}

printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
if (!payloadOp) {
if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
Expand Down
51 changes: 51 additions & 0 deletions mlir/test/Dialect/Linalg/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
// CHECK-SAME: outs
// CHECK-SAME: dimensions = [1]


// -----


// Negative roundtrip test: the region must be printed in full, not as the
// short form `linalg.reduce { arith.addf }`, because the yielded value
// (%in1) is not the result of the payload op (%0).
func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
%reduce = linalg.reduce
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
dimensions = [1]
(%in1: f32, %in2: f32) {
%0 = arith.addf %in1, %in2: f32
linalg.yield %in1: f32
}
func.return %reduce : tensor<16x64xf32>
}

// CHECK-LABEL: func @reduce_not_short_form_compatible
// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32x64xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<16x64xf32>
// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0]] : tensor<16x32x64xf32>
// CHECK: linalg.reduce ins(%[[ARG0]] : tensor<16x32x64xf32>) outs(%[[ARG1]] : tensor<16x64xf32>)
// CHECK-SAME: dimensions = [1]
// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
// CHECK-NEXT: }

// -----

func.func @reduce_memref(%input: memref<16x32x64xf32>,
Expand Down Expand Up @@ -592,6 +620,29 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,

// -----

func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
%res = tensor.empty() : tensor<1x32xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit 1] Use meaningful and consistent names for function arguments.
[nit 2] Keep the number of ops to the required minimum.

Suggested change
func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
%res = tensor.empty() : tensor<1x32xf32>
func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %res: tensor<1x32xf32>) -> tensor<1x32xf32> {
%res = tensor.empty() : tensor<1x32xf32>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, thanks for the suggestion!

%mapped = linalg.map ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%res : tensor<1x32xf32>)
(%in_1: f32, %in_2: f32) {
%1 = arith.maximumf %in_1, %in_2 : f32
linalg.yield %in_1 : f32
}
func.return %mapped : tensor<1x32xf32>
}

// CHECK-LABEL: func @map_not_short_form_compatible
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
// CHECK: %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
// CHECK: linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>)
// CHECK-SAME: outs(%[[RES]] : tensor<1x32xf32>)
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
// CHECK-NEXT: }

// -----

func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
%reduce = linalg.reduce
Expand Down