@@ -185,3 +185,67 @@ gpu.func @gather_from_subview(%source: memref<4096x4096xf16>,
185185// CHECK: %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS]] : vector<8xi1>, vector<8xf16>
186186// CHECK: gpu.return %[[RES]] : vector<8xf16>
187187}
188+
189+ // -----
// Gather from a 1-D memref whose stride and offset are both dynamic.
// The expected output reads the stride from the memref metadata, scales both
// the scalar offset and the index vector by it, and loads through the raw
// aligned base pointer (as i64). Masked-off lanes are filled from the
// pass-through vector by a trailing select.
gpu.module @xevm_module {
gpu.func @non_unit_inner_stride_1D(
    %source: memref<32xf32, strided<[?], offset: ?>>,
    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>,
    %pass_thru: vector<8xf32>) -> vector<8xf32> {
  %0 = vector.gather %source[%off][%indices], %mask, %pass_thru
    : memref<32xf32, strided<[?], offset: ?>>,
      vector<8xindex>, vector<8xi1>, vector<8xf32>
      into vector<8xf32>
  gpu.return %0 : vector<8xf32>
}
// CHECK-LABEL:  @non_unit_inner_stride_1D(
// CHECK-SAME:   %[[SRC:.+]]: memref<32xf32, strided<[?], offset: ?>>,
// CHECK-SAME:   %[[OFF1:.+]]: index,
// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>,
// CHECK-SAME:   %[[MASK:.+]]: vector<8xi1>, %[[PASS:.+]]: vector<8xf32>) -> vector<8xf32> {
// CHECK:        %[[BB:.+]], %[[M_OFF:.+]], %[[SZ:.+]], %[[STRIDE:.+]] = memref.extract_strided_metadata %[[SRC]]
// CHECK:        arith.muli %[[OFF1]], %[[STRIDE]] : index
// CHECK:        arith.addi {{.*}} : index
// CHECK:        %[[STRD_VEC:.+]] = vector.broadcast %[[STRIDE]] : index to vector<8xindex>
// The multiply below must consume the stride broadcast captured above, so the
// variable is referenced here rather than captured a second time.
// CHECK:        %[[STRD_INDICES:.+]] = arith.muli %[[STRD_VEC]], %[[INDICES]] : vector<8xindex>
// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<32xf32, strided<[?], offset: ?>> -> index
// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
// CHECK:        %[[V:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
// CHECK:        gpu.return %[[RES]] : vector<8xf32>
}
219+
220+ // -----
// Gather from a 3-D memref with strides [?, 128, 2] and a dynamic offset, so
// the innermost stride is not 1.
// The expected output builds a scalar base offset from the three offsets and
// their strides, scales the index vector, and loads through the raw aligned
// base pointer (as i64). Masked-off lanes are filled from the pass-through
// vector by a trailing select.
gpu.module @xevm_module {
gpu.func @non_unit_inner_stride_3D(
    %source: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
    %off0: index, %off1: index, %off2: index,
    %indices: vector<8xindex>, %mask: vector<8xi1>,
    %pass_thru: vector<8xf32>) -> vector<8xf32> {
  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask, %pass_thru
    : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
      vector<8xindex>, vector<8xi1>, vector<8xf32>
      into vector<8xf32>
  gpu.return %0 : vector<8xf32>
}
// CHECK-LABEL:  @non_unit_inner_stride_3D(
// CHECK-SAME:   %[[SRC:.+]]: memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>>,
// CHECK-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index,
// CHECK-SAME:   %[[INDICES:.+]]: vector<8xindex>, %[[MASK:.+]]: vector<8xi1>,
// CHECK-SAME:   %[[PASS:.+]]: vector<8xf32>) -> vector<8xf32> {
// CHECK:        %[[BB:.+]], %[[M_OFF:.+]], %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
// CHECK:        arith.muli %[[OFF0]], %[[STRIDES]]#0 : index
// CHECK:        arith.addi {{.*}} : index
// Two further scalar multiplies and two further scalar adds are expected for
// the remaining offsets. They are matched order-independently and without
// overlap, so the test does not depend on whether the ops are emitted
// interleaved or grouped.
// CHECK-DAG:    arith.muli {{.*}} : index
// CHECK-DAG:    arith.muli {{.*}} : index
// CHECK-DAG:    arith.addi {{.*}} : index
// CHECK-DAG:    arith.addi {{.*}} : index
// CHECK:        %[[STRD_INDICES:.+]] = arith.muli {{.*}}%[[INDICES]]{{.*}} : vector<8xindex>
// CHECK:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
// CHECK:        %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[STRD_INDICES]] : vector<8xindex>
// CHECK:        %[[BASE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<4x8x32xf32, strided<[?, 128, 2], offset: ?>> -> index
// CHECK:        %[[BASE_I64:.+]] = arith.index_cast %[[BASE]] : index to i64
// CHECK:        %[[V:.+]] = xegpu.load %[[BASE_I64]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
// CHECK:        %[[RES:.+]] = arith.select %[[MASK]], %[[V]], %[[PASS]] : vector<8xi1>, vector<8xf32>
// CHECK:        gpu.return %[[RES]] : vector<8xf32>
}
0 commit comments