@@ -57,7 +57,7 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
5757// CHECK-LABEL: func @vector_cst_maskedload_i2(
5858// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
5959// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
60- // CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
60+ // CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
6161// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
6262// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
6363// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
@@ -74,6 +74,48 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
7474
7575// -----
7676
77+ func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) -> vector<5xi2> {
78+ %0 = memref.alloc() : memref<4x3x5xi2>
79+ %cst = arith.constant dense<0> : vector<3x5xi2>
80+ %mask = vector.constant_mask [2, 2] : vector<3x5xi1>
81+ %ext_mask = vector.extract %mask[1] : vector<5xi1> from vector<3x5xi1>
82+ %c0 = arith.constant 0 : index
83+ %c2 = arith.constant 2 : index
84+ %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru :
85+ memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
86+ return %1 : vector<5xi2>
87+ }
88+
89+ // CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
90+ // CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]
91+ // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
92+ // CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
93+ // CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
94+
95+ // compressed mask, used for emulated masked load
96+ // CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
97+ // CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
98+
99+ // Create a padded and shifted passthru vector
100+ // CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
101+ // CHECK: %[[PADDED_PTH:.+]] = vector.insert_strided_slice %[[PASSTHRU]], %[[EMPTY]]
102+ // CHECK-SAME: {offsets = [2], strides = [1]}
103+
104+ // CHECK: %[[PTH_DOWNCAST:.+]] = vector.bitcast %[[PADDED_PTH]] : vector<8xi2> to vector<2xi8>
105+ // CHECK: %[[C7:.+]] = arith.constant 7 : index
106+ // CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C7]]], %[[EXT_NEW_MASK]], %[[PTH_DOWNCAST]]
107+ // CHECK: %[[DOWNCAST_LOAD:.+]] = vector.bitcast %[[MASKLOAD]]
108+
109+ // Pad and shift the original mask to match the size and location of the loaded value.
110+ // CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
111+ // CHECK: %[[PADDED_MASK:.+]] = vector.insert_strided_slice %[[EXT_ORIG_MASK]], %[[EMPTY_MASK]]
112+ // CHECK-SAME: {offsets = [2], strides = [1]}
113+ // CHECK: %[[SELECT:.+]] = arith.select %[[PADDED_MASK]], %[[DOWNCAST_LOAD]], %[[PADDED_PTH]]
114+ // CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
115+ // CHECK-SAME: {offsets = [2], sizes = [5], strides = [1]}
116+
117+ // -----
118+
77119func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
78120 %0 = memref.alloc() : memref<3x3xi2>
79121 %cst = arith.constant dense<0> : vector<3x3xi2>
@@ -203,7 +245,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
203245// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
204246// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
205247// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
206- // CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
248+ // CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
207249// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
208250
209251// Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -268,18 +310,17 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
268310// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
269311
270312// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
313+
314+ // Emulated masked load from alloc:
271315// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
272316// CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
273317// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
274-
275- // Emulated masked load from alloc:
276318// CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
277319// CHECK: %[[C1:.+]] = arith.constant 1 : index
278320// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
279321// CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
280322
281- // Select from emulated loaded vector and passthru vector:
282- // TODO: fold this part if possible.
323+ // Select from emulated loaded vector and passthru vector: (TODO: fold this part if possible)
283324// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
284325// CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
285326// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
0 commit comments