@@ -341,15 +341,15 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
341341func.func @map_no_inputs (%init: tensor <64 xf32 >) -> tensor <64 xf32 > {
342342 %add = linalg.map
343343 outs (%init:tensor <64 xf32 >)
344- () {
344+ (%out: f32 ) {
345345 %0 = arith.constant 0.0 : f32
346346 linalg.yield %0: f32
347347 }
348348 func.return %add : tensor <64 xf32 >
349349}
350350// CHECK-LABEL: func @map_no_inputs
351351// CHECK: linalg.map outs
352- // CHECK-NEXT: () {
352+ // CHECK-NEXT: (%[[OUT:.*]]: f32 ) {
353353// CHECK-NEXT: arith.constant
354354// CHECK-NEXT: linalg.yield
355355// CHECK-NEXT: }
@@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
361361 %add = linalg.map
362362 ins (%lhs , %rhs: tensor <64 xf32 >, tensor <64 xf32 >)
363363 outs (%init:tensor <64 xf32 >)
364- (%lhs_elem: f32 , %rhs_elem: f32 ) {
364+ (%lhs_elem: f32 , %rhs_elem: f32 , %out: f32 ) {
365365 %0 = arith.addf %lhs_elem , %rhs_elem: f32
366366 linalg.yield %0: f32
367367 }
@@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
378378 linalg.map
379379 ins (%lhs , %rhs: memref <64 xf32 >, memref <64 xf32 >)
380380 outs (%init:memref <64 xf32 >)
381- (%lhs_elem: f32 , %rhs_elem: f32 ) {
381+ (%lhs_elem: f32 , %rhs_elem: f32 , %out: f32 ) {
382382 %0 = arith.addf %lhs_elem , %rhs_elem: f32
383383 linalg.yield %0: f32
384384 }
@@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
393393 %abs = linalg.map
394394 ins (%input:tensor <64 xf32 >)
395395 outs (%init:tensor <64 xf32 >)
396- (%input_elem: f32 ) {
396+ (%input_elem: f32 , %out: f32 ) {
397397 %0 = math.absf %input_elem: f32
398398 linalg.yield %0: f32
399399 }
@@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
408408 linalg.map
409409 ins (%input:memref <64 xf32 >)
410410 outs (%init:memref <64 xf32 >)
411- (%input_elem: f32 ) {
411+ (%input_elem: f32 , %out: f32 ) {
412412 %0 = math.absf %input_elem: f32
413413 linalg.yield %0: f32
414414 }
@@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
604604 %add = linalg.map
605605 ins (%lhs , %rhs: tensor <64 xf32 >, tensor <64 xf32 >)
606606 outs (%init:tensor <64 xf32 >)
607- (%lhs_elem: f32 , %rhs_elem: f32 ) {
607+ (%lhs_elem: f32 , %rhs_elem: f32 , %out: f32 ) {
608608 %0 = arith.addf %lhs_elem , %rhs_elem fastmath <fast > : f32
609609 linalg.yield %0: f32
610610 }
@@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
622622
623623func.func @map_not_short_form_compatible (%lhs: tensor <1 x32 xf32 >, %rhs: tensor <1 x32 xf32 >, %init: tensor <1 x32 xf32 >) -> tensor <1 x32 xf32 > {
624624 %mapped = linalg.map ins (%lhs , %rhs : tensor <1 x32 xf32 >, tensor <1 x32 xf32 >) outs (%init : tensor <1 x32 xf32 >)
625- (%in_1: f32 , %in_2: f32 ) {
625+ (%in_1: f32 , %in_2: f32 , %out: f32 ) {
626626 %1 = arith.maximumf %in_1 , %in_2 : f32
627627 linalg.yield %in_1 : f32
628628 }
@@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
634634// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
635635// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
636636// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
637- // CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
637+ // CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32 ) {
638638// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
639639// CHECK-NEXT: linalg.yield %[[IN1]] : f32
640640// CHECK-NEXT: }
0 commit comments