// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
33
///----------------------------------------------------------------------------------------
/// vector.load
///----------------------------------------------------------------------------------------

48func.func @vector_load_i8 (%arg1: index , %arg2: index ) -> vector <4 xi8 > {
59 %0 = memref.alloc () : memref <3 x4 xi8 >
610 %1 = vector.load %0 [%arg1 , %arg2 ] : memref <3 x4 xi8 >, vector <4 xi8 >
@@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
8286
8387// -----
8488
///----------------------------------------------------------------------------------------
/// vector.transfer_read
///----------------------------------------------------------------------------------------

8593func.func @vector_transfer_read_i4 (%arg1: index , %arg2: index ) -> vector <8 xi4 > {
8694 %c0 = arith.constant 0 : i4
8795 %0 = memref.alloc () : memref <3 x8 xi4 >
@@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
111119
112120// -----
113121
///----------------------------------------------------------------------------------------
/// vector.maskedload
///----------------------------------------------------------------------------------------

114126func.func @vector_maskedload_i8 (%arg1: index , %arg2: index , %arg3: index , %passthru: vector <4 xi8 >) -> vector <4 xi8 > {
115127 %0 = memref.alloc () : memref <3 x4 xi8 >
116128 %mask = vector.create_mask %arg3 : vector <4 xi1 >
@@ -190,15 +202,15 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt

// -----

// Masked load of i8 with a compile-time-constant mask. i8 matches the target
// load bitwidth (memref-load-bitwidth=8), so no emulation is expected here.
func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> {
  %0 = memref.alloc() : memref<3x4xi8>
  %mask = vector.constant_mask [2] : vector<4xi1>
  %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
    memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
  return %1 : vector<4xi8>
}
// Expect no conversions, i8 is supported.
// CHECK: func @vector_maskedload_i8_constant_mask(
// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]: vector<4xi8>)
// CHECK-NEXT:  %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
@@ -208,7 +220,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto
// CHECK-NEXT:  return

// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
// CHECK32: func @vector_maskedload_i8_constant_mask(
// CHECK32-SAME:  %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME:  %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -224,7 +236,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto

// -----

227- func.func @vector_cst_maskedload_i4 (%arg1: index , %arg2: index , %passthru: vector <8 xi4 >) -> vector <3 x8 xi4 > {
239+ func.func @vector_maskedload_i4_constant_mask (%arg1: index , %arg2: index , %passthru: vector <8 xi4 >) -> vector <3 x8 xi4 > {
228240 %0 = memref.alloc () : memref <3 x8 xi4 >
229241 %cst = arith.constant dense <0 > : vector <3 x8 xi4 >
230242 %mask = vector.constant_mask [4 ] : vector <8 xi1 >
@@ -234,7 +246,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
234246 return %2 : vector <3 x8 xi4 >
235247}
// CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
// CHECK: func @vector_maskedload_i4_constant_mask(
// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:  %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
@@ -248,7 +260,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>

// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
// CHECK32: func @vector_maskedload_i4_constant_mask(
// CHECK32-SAME:  %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME:  %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>)
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
@@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto

// -----

///----------------------------------------------------------------------------------------
/// vector.extract -> vector.masked_load
///----------------------------------------------------------------------------------------

266282func.func @vector_extract_maskedload_i4 (%arg1: index ) -> vector <8 x8 x16 xi4 > {
267283 %0 = memref.alloc () : memref <8 x8 x16 xi4 >
268284 %c0 = arith.constant 0 : index
@@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {

// -----

///----------------------------------------------------------------------------------------
/// vector.store
///----------------------------------------------------------------------------------------

356376func.func @vector_store_i8 (%arg0: vector <8 xi8 >, %arg1: index , %arg2: index ) {
357377 %0 = memref.alloc () : memref <4 x8 xi8 >
358378 vector.store %arg0 , %0 [%arg1 , %arg2 ] :memref <4 x8 xi8 >, vector <8 xi8 >
@@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind

// -----

///----------------------------------------------------------------------------------------
/// vector.maskedstore
///----------------------------------------------------------------------------------------

434458func.func @vector_maskedstore_i8 (%arg0: index , %arg1: index , %arg2: index , %value: vector <8 xi8 >) {
435459 %0 = memref.alloc () : memref <3 x8 xi8 >
436460 %mask = vector.create_mask %arg2 : vector <8 xi1 >
@@ -469,14 +493,68 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu

// -----

// Masked store of i4 with a runtime mask length. i4 is narrower than the
// target load bitwidth, so the pass must emulate it as a masked
// read-modify-write on the wider element type.
func.func @vector_maskedstore_i4(
    %idx1: index,
    %idx2: index,
    %num_elements_to_store: index,
    %value: vector<8xi4>) {

  %0 = memref.alloc() : memref<3x8xi4>
  %mask = vector.create_mask %num_elements_to_store : vector<8xi1>
  vector.maskedstore %0[%idx1, %idx2], %mask, %value :
    memref<3x8xi4>, vector<8xi1>, vector<8xi4>
  return
}
// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>

// CHECK-LABEL: func.func @vector_maskedstore_i4(
// CHECK-SAME:  %[[IDX_1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:  %[[IDX_2:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:  %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:  %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]]()[%[[NUM_EL_TO_STORE]]]
// CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>

// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>

// CHECK32-LABEL: func.func @vector_maskedstore_i4(
// CHECK32-SAME:  %[[IDX_1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME:  %[[IDX_2:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME:  %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME:  %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_17]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]]()[%[[NUM_EL_TO_STORE]]]
// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>

// -----

// Masked store of i8 with a compile-time-constant mask. i8 matches the target
// load bitwidth (memref-load-bitwidth=8), so no emulation is expected here.
func.func @vector_maskedstore_i8_constant_mask(%arg0: index, %arg1: index, %value: vector<8xi8>) {
  %0 = memref.alloc() : memref<3x8xi8>
  %mask = vector.constant_mask [4] : vector<8xi1>
  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
  return
}
// Expect no conversions, i8 is supported.
// CHECK: func @vector_maskedstore_i8_constant_mask(
// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME:  %[[VAL:[a-zA-Z0-9]+]]
@@ -486,7 +564,7 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
// CHECK-NEXT:  return

// CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
// CHECK32: func @vector_maskedstore_i8_constant_mask(
// CHECK32-SAME:  %[[ARG0:[a-zA-Z0-9]+]]
// CHECK32-SAME:  %[[ARG1:[a-zA-Z0-9]+]]
// CHECK32-SAME:  %[[VAL:[a-zA-Z0-9]+]]
@@ -500,3 +578,49 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]

// -----

// Masked store of i4 with a compile-time-constant mask. i4 is narrower than
// the target load bitwidth, so the pass must emulate it as a masked
// read-modify-write on the wider element type, narrowing the constant mask.
func.func @vector_maskedstore_i4_constant_mask(
    %idx_1: index,
    %idx_2: index,
    %val_to_store: vector<8xi4>) {

  %0 = memref.alloc() : memref<3x8xi4>
  %mask = vector.constant_mask [4] : vector<8xi1>
  vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store :
    memref<3x8xi4>, vector<8xi1>, vector<8xi4>
  return
}

// CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
// CHECK-LABEL: func.func @vector_maskedstore_i4_constant_mask(
// CHECK-SAME:  %[[IDX_1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:  %[[IDX_2:[a-zA-Z0-9]+]]: index,
// CHECK-SAME:  %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>

// CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
// CHECK32-LABEL: func.func @vector_maskedstore_i4_constant_mask(
// CHECK32-SAME:  %[[IDX_1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME:  %[[IDX_2:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME:  %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>