@@ -27,8 +27,9 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf32>
 
 }
 
@@ -62,8 +63,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<8x16x32xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -124,8 +126,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
-// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -164,8 +167,9 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER: %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[FINALIDX]]{{\]}}, %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // LOAD-GATHER: gpu.return %[[RES]] : vector<8x16xf32>
 }
 
@@ -195,8 +199,9 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER-DAG: %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER-DAG: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
-// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER-DAG: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<?x8x16xf32> -> index
+// LOAD-GATHER-DAG: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -224,8 +229,9 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
 // LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
 // LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SRC]] : memref<?x?x?x?x?xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
 // LOAD-GATHER: return %[[VEC]]
 }
 
@@ -254,8 +260,9 @@ gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
 // LOAD-GATHER: %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %arg0 : memref<16x32x64xf32> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
 
 }
 
@@ -283,8 +290,9 @@ gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
 // LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD-GATHER: %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
-// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+// LOAD-GATHER: %[[COLLAPSE:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<32x64xf16> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE_I]][%[[IDX]]], %[[CST]] : i64, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
 }
 
 // -----
@@ -396,3 +404,40 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
 // LOAD-GATHER: vector.transfer_read
 }
 
+// -----
+gpu.module @xevm_module {
+gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2: index) -> vector<8xf16> {
+  %c0 = arith.constant 0.0 : f16
+  %subview = memref.subview %source[%off1, %off2] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+  %0 = vector.transfer_read %subview[%off2, %off2], %c0
+    {in_bounds = [true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<8xf16>
+  gpu.return %0 : vector<8xf16>
+}
+
+// LOAD-ND-LABEL: @load_from_subview(
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// LOAD-ND-SAME: boundary_check = false
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND: return %[[VEC]]
+
+// LOAD-GATHER-LABEL: @load_from_subview(
+// LOAD-GATHER-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
+// LOAD-GATHER-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD-GATHER: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-GATHER: %[[BB:.+]], %[[OFFSET:.+]],{{.*}},{{.*}} = memref.extract_strided_metadata %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
+// LOAD-GATHER: %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD-GATHER: arith.muli {{.*}} : index
+// LOAD-GATHER: arith.addi %[[OFFSET]]{{.*}} : index
+// LOAD-GATHER: arith.addi {{.*}} : index
+// LOAD-GATHER: %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
+// LOAD-GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD-GATHER: %[[COLLAPSE:.+]] = memref.extract_aligned_pointer_as_index %[[SUBVIEW]] : memref<256x256xf16, strided<[4096, 1], offset: ?>> -> index
+// LOAD-GATHER: %[[COLLAPSE_I:.+]] = arith.index_cast %[[COLLAPSE]] : index to i64
+// LOAD-GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE_I]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : i64, vector<8xindex>, vector<8xi1> -> vector<8xf16>
+}