@@ -66,7 +66,7 @@ module attributes {transform.with_named_sequence} {
66
66
// -----
67
67
68
68
#map = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 , d2 )>
69
- func.func @vectorize_nd_tensor_extract_constant_idx (%arg0: tensor <3 x3 xf32 >, %arg2: tensor <1 x1 x3 xf32 >) -> tensor <1 x1 x3 xf32 > {
69
+ func.func @vectorize_nd_tensor_extract_scalar_broadcast (%arg0: tensor <3 x3 xf32 >, %arg2: tensor <1 x1 x3 xf32 >) -> tensor <1 x1 x3 xf32 > {
70
70
%c0 = arith.constant 1 : index
71
71
%c1 = arith.constant 2 : index
72
72
%2 = linalg.generic {
@@ -80,17 +80,17 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
80
80
return %2 : tensor <1 x1 x3 xf32 >
81
81
}
82
82
83
- // CHECK: #[[$MAP:.* ]] = affine_map<(d0, d1) -> (0, 0, 0)>
84
- // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_constant_idx (
83
+ // CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
84
+ // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_scalar_broadcast(
85
85
// CHECK-SAME: %[[ARG_0:.*]]: tensor<3x3xf32>,
86
86
// CHECK-SAME: %[[ARG_1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
87
87
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
88
88
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
89
- // CHECK-DAG: %[[C0_f32_2 :.*]] = arith.constant 0.000000e+00 : f32
90
- // CHECK-DAG : %[[C0_f32 :.*]] = arith.constant 0.000000e+00 : f32
91
- // CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG_0]][%[[C1]], %[[C2]]], %[[C0_f32]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32>
92
- // CHECK: %[[C0_4 :.*]] = arith.constant 0 : index
93
- // CHECK: vector.transfer_write %[[READ]], %[[ARG_1]][ %[[C0_4 ]], %[[C0_4 ]], %[[C0_4 ]]] : vector<1x1x3xf32>, tensor<1x1x3xf32>
89
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
90
+ // CHECK: %[[MASK:.*]] = vector.constant_mask [1] : vector<1xi1>
91
+ // CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C1]], %[[C2]]], {{.*}} {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<3x3xf32>, vector<1x1x3xf32> } : vector<1xi1> -> vector<1x1x3xf32>
92
+ // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
93
+ // CHECK: vector.transfer_write %[[READ]], %[[ARG_1]]{{\[}}%[[C0_2]], %[[C0_2]], %[[C0_2]]] : vector<1x1x3xf32>, tensor<1x1x3xf32>
94
94
95
95
module attributes {transform.with_named_sequence } {
96
96
transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
@@ -823,7 +823,7 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
823
823
return %out:tensor <1 x1 x4 xi32 >
824
824
}
825
825
826
- // CHECK: #[[$ATTR_1 :.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
826
+ // CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
827
827
// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
828
828
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
829
829
// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
@@ -844,12 +844,14 @@ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> t
844
844
// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
845
845
// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
846
846
// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
847
- // CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
848
- // CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_20]][0] : index from vector<4xindex>
849
- // CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
850
- // CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
847
+ // CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
848
+ // CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_19]][0] : index from vector<4xindex>
849
+ // CHECK: %[[VAL_21:.*]] = arith.constant 0 : i32
850
+ // CHECK: %[[VAL_22:.*]] = vector.constant_mask [1] : vector<1xi1>
851
+ // CHECK: %[[VAL_23:.*]] = vector.mask %[[VAL_22]] { vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_20]], %[[VAL_2]]], %[[VAL_21]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<15x1xi32>, vector<1x1x4xi32> } : vector<1xi1> -> vector<1x1x4xi32>
851
852
// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
852
853
// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
854
+ // CHECK: return %[[VAL_25]] : tensor<1x1x4xi32>
853
855
854
856
module attributes {transform.with_named_sequence } {
855
857
transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
0 commit comments