@@ -33,22 +33,29 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
3333
3434 %subview = memref.subview %view [32 , 0 ] [32 , 32 ] [1 , 1 ] : memref <64 x32 xf32 , 3 > to memref <32 x32 xf32 , strided <[32 , 1 ], offset : 1024 >, 3 >
3535
36+ //CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]]:2, %[[strides:.*]]:2 = memref.extract_strided_metadata %{{.*}} : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> memref<f32, 3>, index, index, index, index, index
37+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<f32, 3> -> index
38+ //CHECK: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
39+ //CHECK: %[[offset_i32:.*]] = arith.index_castui %[[offset]] : index to i32
40+ //CHECK: %[[c4_i32:.*]] = arith.constant 4 : i32
41+ //CHECK: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
42+ //CHECK: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
43+
3644 %0 = xegpu.create_mem_desc %subview : memref <32 x32 xf32 , strided <[32 , 1 ], offset : 1024 >, 3 > -> !xegpu.mem_desc <32 x32 xf32 >
3745
3846 //CHECK: %[[TID:.*]] = gpu.thread_id x
3947 //CHECK: %[[C1:.*]] = arith.constant 1 : index
4048 //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
41- //CHECK: %[[C4:.*]] = arith.constant 4 : i32
42- //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
49+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
4350 //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
4451
4552 %tid_x = gpu.thread_id x
46-
53+
4754 %1 = xegpu.load_matrix %0 [%c0 , %tid_x ]: !xegpu.mem_desc <32 x32 xf32 >, index , index -> f32
4855
4956 //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
5057
51- xegpu.store_matrix %1 , %0 [%c0 , %tid_x ]: f32 , !xegpu.mem_desc <32 x32 xf32 >, index , index
58+ xegpu.store_matrix %1 , %0 [%c0 , %tid_x ]: f32 , !xegpu.mem_desc <32 x32 xf32 >, index , index
5259
5360 gpu.return %1: f32
5461 }
@@ -99,8 +106,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
99106 //CHECK-LABEL: load_store_matrix_blocked_nostride
100107 gpu.func @load_store_matrix_blocked_nostride (%arg0: memref <4096 xi8 , 3 >) -> f16 {
101108
102- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
103- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
104109 %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <block = [16 , 16 ]>>
105110
106111 //CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -178,7 +183,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
178183 //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
179184 gpu.func @load_store_matrix_blocked_subgroupblockio (%arg0: memref <4096 xi8 , 3 >) -> vector <8 xf16 > {
180185
181- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
186+ //CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]], %[[strides:.*]] = memref.extract_strided_metadata %arg0 : memref<4096xi8, 3> -> memref<i8, 3>, index, index, index
187+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<i8, 3> -> index
182188 //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
183189 %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <block = [16 , 16 ]>>
184190
@@ -206,7 +212,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
206212 //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
207213 //CHECK: %[[c2:.*]] = arith.constant 2 : i32
208214 //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
209- //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32
215+ //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32:.*]], %[[byteOffset]] : i32
210216 //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
211217 //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
212218 //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
0 commit comments