@@ -57,7 +57,7 @@ func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector
5757// CHECK-LABEL: func @vector_constant_mask_maskedload_i2(
5858// CHECK-SAME: %[[ARG0:.+]]: vector<5xi2>) -> vector<3x5xi2>
5959// CHECK: %[[ORIGINMASK:.+]] = vector.constant_mask [3] : vector<5xi1>
60- // CHECK: %[[NEWMASK:.+]] = arith.constant dense<true> : vector<2xi1>
60+ // CHECK: %[[NEWMASK:.+]] = vector.constant_mask [2] : vector<2xi1>
6161// CHECK: %[[VESSEL:.+]] = arith.constant dense<0> : vector<8xi2>
6262// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[ARG0]], %[[VESSEL]]
6363// CHECK-SAME: {offsets = [2], strides = [1]} : vector<5xi2> into vector<8xi2>
@@ -123,6 +123,47 @@ func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vec
123123
124124// -----
125125
126+ func.func @vector_constant_mask_maskedload_i2_multidim (%passthru: vector <5 xi2 >) -> vector <5 xi2 > {  // Test input: masked load of 5 x i2 whose mask is one row extracted from a 2-D constant mask.
127+ %0 = memref.alloc () : memref <4 x3 x5 xi2 >  // Source buffer: 4x3x5 elements of i2.
128+ %mask = vector.constant_mask [2 , 2 ] : vector <3 x5 xi1 >  // 2-D mask with constant_mask bounds [2, 2] over a 3x5 shape.
129+ %ext_mask = vector.extract %mask [1 ] : vector <5 xi1 > from vector <3 x5 xi1 >  // Take row 1 of the 2-D mask as the 1-D load mask.
130+ %c0 = arith.constant 0 : index
131+ %c2 = arith.constant 2 : index
132+ %1 = vector.maskedload %0 [%c2 , %c0 , %c0 ], %ext_mask , %passthru :  // Masked load starting at indices [2, 0, 0]; %passthru supplies the pass-through operand.
133+ memref <4 x3 x5 xi2 >, vector <5 xi1 >, vector <5 xi2 > into vector <5 xi2 >
134+ return %1 : vector <5 xi2 >
135+ }
136+
137+ // CHECK-LABEL: func @vector_constant_mask_maskedload_i2_multidim(
138+ // CHECK-SAME: %[[PASSTHRU:[a-zA-Z0-9]+]]
139+ // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<15xi8>
140+ // CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [2, 2] : vector<3x5xi1>
141+ // CHECK: %[[EXT_ORIG_MASK:.+]] = vector.extract %[[ORIG_MASK]][1]
142+
143+ // Compressed mask, used for the emulated masked load.
144+ // CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2, 1] : vector<3x2xi1>
145+ // CHECK: %[[EXT_NEW_MASK:.+]] = vector.extract %[[NEW_MASK]][1]
146+
147+ // Create a padded and shifted passthru vector
148+ // CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
149+ // CHECK: %[[PADDED_PTH:.+]] = vector.insert_strided_slice %[[PASSTHRU]], %[[EMPTY]]
150+ // CHECK-SAME: {offsets = [2], strides = [1]}
151+
152+ // CHECK: %[[PTH_DOWNCAST:.+]] = vector.bitcast %[[PADDED_PTH]] : vector<8xi2> to vector<2xi8>
153+ // CHECK: %[[C7:.+]] = arith.constant 7 : index
154+ // CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C7]]], %[[EXT_NEW_MASK]], %[[PTH_DOWNCAST]]
155+ // CHECK: %[[DOWNCAST_LOAD:.+]] = vector.bitcast %[[MASKLOAD]]
156+
157+ // Pad and shift the original mask to match the size and location of the loaded value.
158+ // CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
159+ // CHECK: %[[PADDED_MASK:.+]] = vector.insert_strided_slice %[[EXT_ORIG_MASK]], %[[EMPTY_MASK]]
160+ // CHECK-SAME: {offsets = [2], strides = [1]}
161+ // CHECK: %[[SELECT:.+]] = arith.select %[[PADDED_MASK]], %[[DOWNCAST_LOAD]], %[[PADDED_PTH]]
162+ // CHECK: %[[RESULT:.+]] = vector.extract_strided_slice %[[SELECT]]
163+ // CHECK-SAME: {offsets = [2], sizes = [5], strides = [1]}
164+
165+ // -----
166+
126167func.func @vector_load_i2_dynamic_indexing (%idx1: index , %idx2: index ) -> vector <3 xi2 > {
127168 %0 = memref.alloc () : memref <3 x3 xi2 >
128169 %cst = arith.constant dense <0 > : vector <3 x3 xi2 >
@@ -252,7 +293,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
252293// CHECK: %[[MASK:.+]] = vector.constant_mask [3] : vector<3xi1>
253294// CHECK: %[[LINEAR1:.+]] = affine.apply #map()[%[[IDX]]]
254295// CHECK: %[[LINEAR2:.+]] = affine.apply #map1()[%[[IDX]]]
255- // CHECK: %[[ONE:.+]] = arith.constant dense<true> : vector<2xi1>
296+ // CHECK: %[[ONE:.+]] = vector.constant_mask [2] : vector<2xi1>
256297// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<8xi2>
257298
258299// Extract passthru vector, and insert into zero vector, this is for constructing a new passthru
@@ -301,7 +342,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
301342
302343// -----
303344
304- func.func @vector_maskedload_i4_constant_mask_unaligned (%passthru: vector <5 xi2 >) -> vector <5 xi2 > {
345+ func.func @vector_maskedload_i2_constant_mask_unaligned (%passthru: vector <5 xi2 >) -> vector <5 xi2 > {
305346 %0 = memref.alloc () : memref <3 x5 xi2 >
306347 %mask = arith.constant dense <[false , true , true , true , false ]> : vector <5 xi1 >
307348 %c0 = arith.constant 0 : index
@@ -311,24 +352,23 @@ func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>)
311352 return %1 : vector <5 xi2 >
312353}
313354
314- // CHECK: func @vector_maskedload_i4_constant_mask_unaligned (
355+ // CHECK: func @vector_maskedload_i2_constant_mask_unaligned (
315356// CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
316357// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
317358// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
318359
360+ // Emulated masked load from alloc:
319361// CHECK: %[[COMPRESSED_MASK:.+]] = arith.constant dense<true> : vector<2xi1>
320362// CHECK: %[[EMPTY:.+]] = arith.constant dense<0> : vector<8xi2>
321363// CHECK: %[[PTH_PADDED:.+]] = vector.insert_strided_slice %[[PTH]], %[[EMPTY]]
322364// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
323-
324- // Emulated masked load from alloc:
325365// CHECK: %[[PTH_PADDED_UPCAST:.+]] = vector.bitcast %[[PTH_PADDED]] : vector<8xi2> to vector<2xi8>
326366// CHECK: %[[C1:.+]] = arith.constant 1 : index
327367// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[COMPRESSED_MASK]], %[[PTH_PADDED_UPCAST]]
328368// CHECK: %[[MASKLOAD_DOWNCAST:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
329369
330370// Select from emulated loaded vector and passthru vector:
331- // TODO: fold this part if possible.
371+ // TODO: fold insert_strided_slice into source if possible.
332372// CHECK: %[[EMPTY_MASK:.+]] = arith.constant dense<false> : vector<8xi1>
333373// CHECK: %[[MASK_PADDED:.+]] = vector.insert_strided_slice %[[MASK]], %[[EMPTY_MASK]]
334374// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
0 commit comments