Skip to content

Commit 5189c8f

Browse files
committed
address comments
1 parent 233b8a4 commit 5189c8f

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [
253253
}
254254
```
255255

256-
Shortened print form is available. Applies to simple maps with one
257-
non-yield operation inside the body.
256+
Shortened print form is available for simple maps where the body contains exactly
257+
two operations (the payload operation and a yield), the payload operation has
258+
the same number of operands as block arguments with operands matching block
259+
arguments in order, and the yield operand is the result of the payload operation.
258260

259-
The example above will be printed as:
261+
The example above will be printed using the shortened form as:
260262
```mlir
261263
%add = linalg.map { arith.addf }
262264
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
@@ -340,13 +342,15 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
340342
}
341343
```
342344

343-
Shortened print form is available. Applies to simple (not variadic) reduces
344-
with one non-yield operation inside the body. Applies only if the operation
345-
takes `%out` as the first argument.
345+
Shortened print form is available for simple reduces where the body contains exactly
346+
two operations (the payload operation and a yield), the payload operation has the
347+
same number of operands as block arguments, the first block argument (init) is the
348+
last operand of the payload operation with remaining operands matching remaining
349+
block arguments in order, and the yield operand is the result of the payload operation.
346350

347-
The example above will be printed as:
351+
The example above will be printed using the shortened form as:
348352
```mlir
349-
%reduce = linalg.reduce { arith.addf }
353+
%reduce = linalg.reduce { arith.addf }
350354
ins(%input:tensor<16x32x64xf32>)
351355
outs(%init:tensor<16x64xf32>)
352356
dimensions = [1]

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -454,10 +454,10 @@ func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
454454
}
455455

456456
// CHECK-LABEL: func @reduce_not_short_form_compatible
457-
// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32x64xf32>
458-
// CHECK-SAME: %[[ARG1:.*]]: tensor<16x64xf32>
459-
// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0]] : tensor<16x32x64xf32>
460-
// CHECK: linalg.reduce ins(%[[ARG0]] : tensor<16x32x64xf32>) outs(%[[ARG1]] : tensor<16x64xf32>)
457+
// CHECK-SAME: %[[INPUT:.*]]: tensor<16x32x64xf32>
458+
// CHECK-SAME: %[[INIT:.*]]: tensor<16x64xf32>
459+
// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<16x32x64xf32>
460+
// CHECK: linalg.reduce ins(%[[INPUT]] : tensor<16x32x64xf32>) outs(%[[INIT]] : tensor<16x64xf32>)
461461
// CHECK-SAME: dimensions = [1]
462462
// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
463463
// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
@@ -620,9 +620,8 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
620620

621621
// -----
622622

623-
func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
624-
%res = tensor.empty() : tensor<1x32xf32>
625-
%mapped = linalg.map ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%res : tensor<1x32xf32>)
623+
func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
624+
%mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
626625
(%in_1: f32, %in_2: f32) {
627626
%1 = arith.maximumf %in_1, %in_2 : f32
628627
linalg.yield %in_1 : f32
@@ -631,11 +630,10 @@ func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<
631630
}
632631

633632
// CHECK-LABEL: func @map_not_short_form_compatible
634-
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
635-
// CHECK: %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
636-
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
637-
// CHECK: linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>)
638-
// CHECK-SAME: outs(%[[RES]] : tensor<1x32xf32>)
633+
// CHECK-SAME: %[[LHS:.*]]: tensor<1x32xf32>, %[[RHS:.*]]: tensor<1x32xf32>, %[[INIT:.*]]: tensor<1x32xf32>
634+
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
635+
// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
636+
// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
639637
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
640638
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
641639
// CHECK-NEXT: linalg.yield %[[IN1]] : f32

0 commit comments

Comments
 (0)