Skip to content

Commit e50edb7

Browse files
committed
add lit tests for reduce and map cases that can only use long form printing
1 parent 16080a0 commit e50edb7

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
436436
// CHECK-SAME: outs
437437
// CHECK-SAME: dimensions = [1]
438438

439+
440+
// -----
441+
442+
443+
func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
444+
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
445+
%reduce = linalg.reduce
446+
ins(%input:tensor<16x32x64xf32>)
447+
outs(%init:tensor<16x64xf32>)
448+
dimensions = [1]
449+
(%in1: f32, %in2: f32) {
450+
%0 = arith.addf %in1, %in2: f32
451+
linalg.yield %in1: f32
452+
}
453+
func.return %reduce : tensor<16x64xf32>
454+
}
455+
456+
// CHECK-LABEL: func @reduce_not_short_form_compatible
457+
// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32x64xf32>
458+
// CHECK-SAME: %[[ARG1:.*]]: tensor<16x64xf32>
459+
// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0]] : tensor<16x32x64xf32>
460+
// CHECK: linalg.reduce ins(%[[ARG0]] : tensor<16x32x64xf32>) outs(%[[ARG1]] : tensor<16x64xf32>)
461+
// CHECK-SAME: dimensions = [1]
462+
// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
463+
// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
464+
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
465+
// CHECK-NEXT: }
466+
439467
// -----
440468

441469
func.func @reduce_memref(%input: memref<16x32x64xf32>,
@@ -592,6 +620,29 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
592620

593621
// -----
594622

623+
func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
624+
%res = tensor.empty() : tensor<1x32xf32>
625+
%mapped = linalg.map ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%res : tensor<1x32xf32>)
626+
(%in_1: f32, %in_2: f32) {
627+
%1 = arith.maximumf %in_1, %in_2 : f32
628+
linalg.yield %in_1 : f32
629+
}
630+
func.return %mapped : tensor<1x32xf32>
631+
}
632+
633+
// CHECK-LABEL: func @map_not_short_form_compatible
634+
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
635+
// CHECK: %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
636+
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
637+
// CHECK: linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>)
638+
// CHECK-SAME: outs(%[[RES]] : tensor<1x32xf32>)
639+
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
640+
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
641+
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
642+
// CHECK-NEXT: }
643+
644+
// -----
645+
595646
func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
596647
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
597648
%reduce = linalg.reduce

0 commit comments

Comments
 (0)