@@ -249,3 +249,174 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
 // CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
 // CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_store_i2_const(%arg0: vector<3xi2>) {
+  %0 = memref.alloc() : memref<3x3xi2>
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  vector.store %arg0, %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+  return
+}
+
+// in this example, emit 2 atomic stores, with the first storing 2 elements and the second storing 1 element
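+// The store starts at element offset 2 * 3 + 0 = 6, i.e. bit 12, which is bit 4 of
+// byte 1: byte 1 takes 2 elements at intra-byte offsets 2 and 3, and byte 2 takes
+// the remaining element at intra-byte offset 0. Both bytes are only partially
+// covered, so each is updated through an atomic RMW.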
+// CHECK: func @vector_store_i2_const(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// atomic store of the first byte
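+// Each partial byte is updated inside memref.generic_atomic_rmw: the loaded byte
+// is bitcast to vector<4xi2>, the new elements are selected in using a constant
+// vector<4xi1> mask, and the result is bitcast back to i8 and yielded.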
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// atomic store of the second byte
+// CHECK: %[[ADDI:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDI]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST3]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST4]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT3]] : i8
+
+// -----
+
+func.func @vector_store_i8_2(%arg0: vector<7xi2>) {
+  %0 = memref.alloc() : memref<3x7xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
+  return
+}
+
+// in this example, emit 2 atomic stores and 1 non-atomic store
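+// The store starts at element offset 1 * 7 + 0 = 7, i.e. bit 14, which is bit 6 of
+// byte 1. The 7 elements therefore split into 1 element completing byte 1 (atomic
+// RMW), 4 elements exactly filling byte 2 (plain vector.store), and 2 elements
+// starting byte 3 (atomic RMW).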
+
+// CHECK: func @vector_store_i8_2(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// first atomic store
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// non-atomic store part
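+// Byte 2 is fully covered by elements 1 through 4 of the source vector, so it is
+// written with a plain vector.store of the bitcast subvector instead of an RMW.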
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// second atomic store
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+  %0 = memref.alloc() : memref<4x1xi2>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2>
+  return
+}
+
+// in this example, only emit 1 atomic store
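+// The single element lands at offset 1 * 1 + 0 = 1, i.e. intra-byte offset 1 of
+// byte 0, so a single atomic RMW on byte 0 suffices.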
+// CHECK: func @vector_store_i2_single_atomic(
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
+
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C0]]] : memref<1xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// -----
+
+func.func @vector_maskedload_i4_constant_mask_unaligned(%passthru: vector<5xi2>) -> vector<5xi2> {
+  %0 = memref.alloc() : memref<3x5xi2>
+  %mask = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
+    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+  return %1 : vector<5xi2>
+}
+
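+// The load starts at element offset 1 * 5 + 0 = 5, i.e. bit 10, which is bit 2 of
+// byte 1. The 5 elements span bytes 1 and 2, so 2 bytes are loaded and the mask
+// and passthru are shifted by 1 element (the intra-byte offset) to compensate.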
+// CHECK: func @vector_maskedload_i4_constant_mask_unaligned(
+// CHECK-SAME: %[[PTH:.+]]: vector<5xi2>) -> vector<5xi2>
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[MASK:.+]] = arith.constant dense<[false, true, true, true, false]> : vector<5xi1>
+
+// CHECK: %[[CST0:.+]] = arith.constant dense<true> : vector<2xi1>
+// CHECK: %[[CST1:.+]] = arith.constant dense<0> : vector<8xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[PTH]], %[[CST1]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi2> into vector<8xi2>
+
+// Emulated masked load from alloc:
+// CHECK: %[[BCAST:.+]] = vector.bitcast %[[INSERT]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[MASKLOAD:.+]] = vector.maskedload %[[ALLOC]][%[[C1]]], %[[CST0]], %[[BCAST]]
+// CHECK: %[[BCAST2:.+]] = vector.bitcast %[[MASKLOAD]] : vector<2xi8> to vector<8xi2>
+
+// Select between the emulated load result and the passthru vector:
+// TODO: fold this part if possible.
+// CHECK: %[[CST2:.+]] = arith.constant dense<false> : vector<8xi1>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[MASK]], %[[CST2]]
+// CHECK-SAME: {offsets = [1], strides = [1]} : vector<5xi1> into vector<8xi1>
+// CHECK: %[[SELECT:.+]] = arith.select %[[INSERT2]], %[[BCAST2]], %[[INSERT]] : vector<8xi1>, vector<8xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SELECT]]
+// CHECK-SAME: {offsets = [1], sizes = [5], strides = [1]} : vector<8xi2> to vector<5xi2>
+// CHECK: return %[[EXTRACT]] : vector<5xi2>