@@ -57,13 +57,14 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){
5757// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1x16xf32>,
5858// CHECK-SAME: %[[ARG1:.*]]: tensor<2x1x16xf16>
5959tt.func @reduce (%arg0: tensor <2 x1 x16 xf32 >, %arg1: tensor <2 x1 x16 xf16 >) -> (tensor <2 x16 xf32 >, tensor <2 x16 xf16 >) {
60- // CHECK: tt.reshape %[[ARG0]] allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32>
61- // CHECK: tt.reshape %[[ARG1]] allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16>
60+ // CHECK: %[[VAL0:.*]] = tt.reshape %[[ARG0]] allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32>
61+ // CHECK: %[[VAL1:.*]] = tt.reshape %[[ARG1]] allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16>
6262 %0:2 = " tt.reduce" (%arg0 , %arg1 ) <{axis =1 : i32 }> ({
6363 ^bb0 (%acc0: f32 , %acc1: f16 , %curr0: f32 , %curr1: f16 ):
6464 %1 = arith.addf %acc0 , %curr0 : f32
6565 %2 = arith.mulf %acc1 , %curr1 : f16
6666 tt.reduce.return %1 , %2 : f32 , f16
6767 }) : (tensor <2 x1 x16 xf32 >, tensor <2 x1 x16 xf16 >) -> (tensor <2 x16 xf32 >, tensor <2 x16 xf16 >)
68+ // CHECK: tt.return %[[VAL0]], %[[VAL1]] : tensor<2x16xf32>, tensor<2x16xf16>
6869 tt.return %0#0 , %0#1 : tensor <2 x16 xf32 >, tensor <2 x16 xf16 >
6970}
0 commit comments