@@ -2874,6 +2874,22 @@ func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
28742874
28752875// -----
28762876
2877+ // CHECK-LABEL: @contiguous_gather_step
2878+ // CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
2879+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
2880+ // CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
2881+ // CHECK: return %[[R]]
// Test input: a masked gather from %base at constant offset 0 whose index
// vector comes from vector.step instead of a constant. The FileCheck
// expectations directly preceding this function require the output to contain
// a single vector.maskedload on %base at offset 0 that reuses the same mask and
// passthru operands, and to return that load's result.
2882+ func.func @contiguous_gather_step (%base: memref <?xf32 >,
2883+ %mask: vector <16 xi1 >, %passthru: vector <16 xf32 >) -> vector <16 xf32 > {
2884+ %c0 = arith.constant 0 : index
// Per the vector dialect documentation, vector.step yields the lane indices
// 0, 1, ..., 15 for a 16-element result, i.e. consecutive offsets.
2885+ %indices = vector.step : vector <16 xindex >
// The gather under test: base offset %c0, per-lane offsets %indices.
2886+ %1 = vector.gather %base [%c0 ][%indices ], %mask , %passthru :
2887+ memref <?xf32 >, vector <16 xindex >, vector <16 xi1 >, vector <16 xf32 > into vector <16 xf32 >
2888+ return %1 : vector <16 xf32 >
2889+ }
2890+
2891+ // -----
2892+
28772893// CHECK-LABEL: @contiguous_scatter
28782894// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
28792895// CHECK: %[[C0:.*]] = arith.constant 0 : index
@@ -2902,3 +2918,18 @@ func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
29022918 memref <?xf32 >, vector <16 xi32 >, vector <16 xi1 >, vector <16 xf32 >
29032919 return
29042920}
2921+
2922+ // -----
2923+
2924+ // CHECK-LABEL: @contiguous_scatter_step
2925+ // CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
2926+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
2927+ // CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
// Test input: a masked scatter into %base at constant offset 0 whose index
// vector comes from vector.step instead of a constant. The FileCheck
// expectations directly preceding this function require the output to contain
// a vector.maskedstore on %base at offset 0 that reuses the same mask and
// value operands.
2928+ func.func @contiguous_scatter_step (%base: memref <?xf32 >,
2929+ %mask: vector <16 xi1 >, %value: vector <16 xf32 >) {
2930+ %c0 = arith.constant 0 : index
// Per the vector dialect documentation, vector.step yields the lane indices
// 0, 1, ..., 15 for a 16-element result, i.e. consecutive offsets.
2931+ %indices = vector.step : vector <16 xindex >
// The scatter under test: base offset %c0, per-lane offsets %indices.
2932+ vector.scatter %base [%c0 ][%indices ], %mask , %value :
2933+ memref <?xf32 >, vector <16 xindex >, vector <16 xi1 >, vector <16 xf32 >
2934+ return
2935+ }
0 commit comments