@@ -361,6 +361,74 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
361361/// vector.store
362362///----------------------------------------------------------------------------------------
363363
364+ // -----
365+
366+ // Most basic example to demonstrate where partial stores are not needed.
367+
368+ func.func @vector_store_i2_const_index_no_partial_store (%arg0: vector <4 xi2 >) {
369+ %0 = memref.alloc () : memref <13 xi2 >
370+ %c4 = arith.constant 4 : index
371+ vector.store %arg0 , %0 [%c4 ] : memref <13 xi2 >, vector <4 xi2 >
372+ return
373+ }
374+ // CHECK-LABEL: func.func @vector_store_i2_const_index_no_partial_store(
375+ // CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
376+ // CHECK-NOT: memref.generic_atomic_rmw
377+ // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8>
378+ // CHECK: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4xi2> to vector<1xi8>
379+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
380+ // CHECK: vector.store %[[UPCAST]], %[[ALLOC]]{{\[}}%[[C1]]] : memref<4xi8>, vector<1xi8>
381+
382+ // -----
383+
384+ // Small modification of the example above to demonstrate where partial stores
385+ // are needed.
386+
387+ func.func @vector_store_i2_const_index_two_partial_stores (%arg0: vector <4 xi2 >) {
388+ %0 = memref.alloc () : memref <13 xi2 >
389+ %c3 = arith.constant 3 : index
390+ vector.store %arg0 , %0 [%c3 ] : memref <13 xi2 >, vector <4 xi2 >
391+ return
392+ }
393+
394+ // CHECK-LABEL: func.func @vector_store_i2_const_index_two_partial_stores(
395+ // CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
396+ // CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<4xi8>
397+
398+ // First atomic RMW:
399+ // CHECK: %[[IDX_1:.*]] = arith.constant 0 : index
400+ // CHECK: %[[MASK_1:.*]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
401+ // CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<4xi2>
402+ // CHECK: %[[SLICE_1:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
403+ // CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[SLICE_1]], %[[INIT]] {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
404+ // CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_1]]] : memref<4xi8> {
405+ // CHECK: ^bb0(%[[VAL_8:.*]]: i8):
406+ // CHECK: %[[VAL_9:.*]] = vector.from_elements %[[VAL_8]] : vector<1xi8>
407+ // CHECK: %[[DOWNCAST_1:.*]] = vector.bitcast %[[VAL_9]] : vector<1xi8> to vector<4xi2>
408+ // CHECK: %[[SELECT_1:.*]] = arith.select %[[MASK_1]], %[[V1]], %[[DOWNCAST_1]] : vector<4xi1>, vector<4xi2>
409+ // CHECK: %[[UPCAST_1:.*]] = vector.bitcast %[[SELECT_1]] : vector<4xi2> to vector<1xi8>
410+ // CHECK: %[[RES_1:.*]] = vector.extract %[[UPCAST_1]][0] : i8 from vector<1xi8>
411+ // CHECK: memref.atomic_yield %[[RES_1]] : i8
412+ // CHECK: }
413+
414+ // Second atomic RMW:
415+ // CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
416+ // CHECK: %[[IDX_2:.*]] = arith.addi %[[IDX_1]], %[[VAL_14]] : index
417+ // CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
418+ // CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[INIT]] {offsets = [0], strides = [1]} : vector<3xi2> into vector<4xi2>
419+ // CHECK: %[[MASK_2:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
420+ // CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_2]]] : memref<4xi8> {
421+ // CHECK: ^bb0(%[[VAL_20:.*]]: i8):
422+ // CHECK: %[[VAL_21:.*]] = vector.from_elements %[[VAL_20]] : vector<1xi8>
423+ // CHECK: %[[DONWCAST_2:.*]] = vector.bitcast %[[VAL_21]] : vector<1xi8> to vector<4xi2>
424+ // CHECK: %[[SELECT_2:.*]] = arith.select %[[MASK_2]], %[[V2]], %[[DONWCAST_2]] : vector<4xi1>, vector<4xi2>
425+ // CHECK: %[[UPCAST_2:.*]] = vector.bitcast %[[SELECT_2]] : vector<4xi2> to vector<1xi8>
426+ // CHECK: %[[RES_2:.*]] = vector.extract %[[UPCAST_2]][0] : i8 from vector<1xi8>
427+ // CHECK: memref.atomic_yield %[[RES_2]] : i8
428+ // CHECK: }
429+
430+ // -----
431+
364432func.func @vector_store_i2_const_index_two_partial_stores (%arg0: vector <3 xi2 >) {
365433 %src = memref.alloc () : memref <3 x3 xi2 >
366434 %c0 = arith.constant 0 : index
0 commit comments