@@ -454,10 +454,10 @@ func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
454454}
455455
456456// CHECK-LABEL: func @reduce_not_short_form_compatible
457- // CHECK-SAME: %[[ARG0 :.*]]: tensor<16x32x64xf32>
458- // CHECK-SAME: %[[ARG1 :.*]]: tensor<16x64xf32>
459- // CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0 ]] : tensor<16x32x64xf32>
460- // CHECK: linalg.reduce ins(%[[ARG0 ]] : tensor<16x32x64xf32>) outs(%[[ARG1 ]] : tensor<16x64xf32>)
457+ // CHECK-SAME: %[[INPUT :.*]]: tensor<16x32x64xf32>
458+ // CHECK-SAME: %[[INIT :.*]]: tensor<16x64xf32>
459+ // CHECK-NOT: linalg.reduce { arith.addf } ins(%[[INPUT ]] : tensor<16x32x64xf32>
460+ // CHECK: linalg.reduce ins(%[[INPUT ]] : tensor<16x32x64xf32>) outs(%[[INIT ]] : tensor<16x64xf32>)
461461// CHECK-SAME: dimensions = [1]
462462// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
463463// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
@@ -620,9 +620,8 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
620620
621621// -----
622622
623- func.func @map_not_short_form_compatible (%arg0: tensor <1 x32 xf32 >, %arg1: tensor <1 x32 xf32 >) -> tensor <1 x32 xf32 > {
624- %res = tensor.empty () : tensor <1 x32 xf32 >
625- %mapped = linalg.map ins (%arg0 , %arg1 : tensor <1 x32 xf32 >, tensor <1 x32 xf32 >) outs (%res : tensor <1 x32 xf32 >)
623+ func.func @map_not_short_form_compatible (%lhs: tensor <1 x32 xf32 >, %rhs: tensor <1 x32 xf32 >, %init: tensor <1 x32 xf32 >) -> tensor <1 x32 xf32 > {
624+ %mapped = linalg.map ins (%lhs , %rhs : tensor <1 x32 xf32 >, tensor <1 x32 xf32 >) outs (%init : tensor <1 x32 xf32 >)
626625 (%in_1: f32 , %in_2: f32 ) {
627626 %1 = arith.maximumf %in_1 , %in_2 : f32
628627 linalg.yield %in_1 : f32
@@ -631,11 +630,10 @@ func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<
631630}
632631
633632// CHECK-LABEL: func @map_not_short_form_compatible
634- // CHECK-SAME: %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
635- // CHECK: %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
636- // CHECK-NOT: linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
637- // CHECK: linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>)
638- // CHECK-SAME: outs(%[[RES]] : tensor<1x32xf32>)
633+ // CHECK-SAME: %[[LHS:.*]]: tensor<1x32xf32>, %[[RHS:.*]]: tensor<1x32xf32>, %[[INIT:.*]]: tensor<1x32xf32>
634+ // CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
635+ // CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
636+ // CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
639637// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
640638// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
641639// CHECK-NEXT: linalg.yield %[[IN1]] : f32
0 commit comments