@@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
44
55 // e.g. for mem_desc<32x32xf32> (no block or stride attribute, i.e. row-major)
66 // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
7- //CHECK-LABEL: load_store_matrix_1
8- gpu.func @load_store_matrix_1 (%arg0: memref <4096 xi8 , 3 >) -> f32 {
7+ //CHECK-LABEL: load_store_matrix_plain
8+ gpu.func @load_store_matrix_plain (%arg0: memref <4096 xi8 , 3 >) -> f32 {
99 %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x32 xf32 >
1010
1111 //CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -26,20 +26,48 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
2626 gpu.return %1: f32
2727 }
2828
29+ //CHECK-LABEL: load_store_matrix_plain_2d_input
30+ gpu.func @load_store_matrix_plain_2d_input (%arg0: memref <8192 xi8 , 3 >) -> f32 {
31+ %c0 = arith.constant 0 : index
32+ %view = memref.view %arg0 [%c0 ][]: memref <8192 xi8 , 3 > to memref <64 x32 xf32 , 3 >
33+ // The view above reinterprets the 8192-byte buffer as 64x32xf32; the subview below selects rows 32..63 (element offset 32*32 = 1024).
34+ %subview = memref.subview %view [32 , 0 ] [32 , 32 ] [1 , 1 ] : memref <64 x32 xf32 , 3 > to memref <32 x32 xf32 , strided <[32 , 1 ], offset : 1024 >, 3 >
35+ // Build the 32x32xf32 mem_desc from the strided, offset 2D subview rather than from the raw 1D i8 buffer.
36+ %0 = xegpu.create_mem_desc %subview : memref <32 x32 xf32 , strided <[32 , 1 ], offset : 1024 >, 3 > -> !xegpu.mem_desc <32 x32 xf32 >
37+ // Expected lowering of the load: thread id, index arithmetic, a multiply by 4 (presumably the f32 byte size -- confirm), then a scalar f32 load via !llvm.ptr<3>.
38+ //CHECK: %[[TID:.*]] = gpu.thread_id x
39+ //CHECK: %[[C1:.*]] = arith.constant 1 : index
40+ //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
41+ //CHECK: %[[C4:.*]] = arith.constant 4 : i32
42+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
43+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
44+
45+ %tid_x = gpu.thread_id x
46+ // Load a single f32 element at coordinates [0, tid_x].
47+ %1 = xegpu.load_matrix %0 [%c0 , %tid_x ]: !xegpu.mem_desc <32 x32 xf32 >, index , index -> f32
48+ // The store below is expected to lower to a scalar f32 llvm.store through !llvm.ptr<3>.
49+ //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
50+ // Write the loaded value back to the same [0, tid_x] position.
51+ xegpu.store_matrix %1 , %0 [%c0 , %tid_x ]: f32 , !xegpu.mem_desc <32 x32 xf32 >, index , index
52+
53+ gpu.return %1: f32
54+ }
55+
56+
2957// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
3058 // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
31- //CHECK-LABEL: load_store_matrix_2
32- gpu.func @load_store_matrix_2 (%arg0: memref <4096 xi8 , 3 >) -> f16 {
59+ //CHECK-LABEL: load_store_matrix_blocked_strided
60+ gpu.func @load_store_matrix_blocked_strided (%arg0: memref <4096 xi8 , 3 >) -> f16 {
3361 %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <stride = [1 , 32 ], block = [16 , 16 ]>>
34- //CHECK: %[[c0:.*]] = arith.constant 0 : index
62+
3563 //CHECK: %[[tid_x:.*]] = gpu.thread_id x
3664 //CHECK: %[[c13:.*]] = arith.constant 13 : index
3765 //CHECK: %[[c16:.*]] = arith.constant 16 : index
3866 //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
3967 //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
4068 //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
4169 //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
42-
70+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
4371 //CHECK: %[[c256:.*]] = arith.constant 256 : index
4472 //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
4573 //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -68,24 +96,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
6896
6997 // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
7098 // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
71- //CHECK-LABEL: load_store_matrix_3
72- gpu.func @load_store_matrix_3 (%arg0: memref <4096 xi8 , 3 >) -> f16 {
73- //CHECK: %[[c0:.*]] = arith.constant 0 : index
74- //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
99+ //CHECK-LABEL: load_store_matrix_blocked_nostride
100+ gpu.func @load_store_matrix_blocked_nostride (%arg0: memref <4096 xi8 , 3 >) -> f16 {
101+
102+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
103+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
75104 %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <block = [16 , 16 ]>>
76105
77106 //CHECK: %[[tid_x:.*]] = gpu.thread_id x
78107 //CHECK: %[[c19:.*]] = arith.constant 19 : index
79108 %tid_x = gpu.thread_id x
80109 %c19 = arith.constant 19 : index
81110
82- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
83- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
84111 //CHECK: %[[c16:.*]] = arith.constant 16 : index
85112 //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
86113 //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
87114 //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
88115 //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
116+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
89117 //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
90118 //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
91119 //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -97,7 +125,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
97125 //CHECK: %[[c1:.*]] = arith.constant 1 : index
98126 //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
99127 //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
100-
101128 //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
102129 %1 = xegpu.load_matrix %0 [%c19 , %tid_x ]: !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <block = [16 , 16 ]>>, index , index -> f16
103130
@@ -110,19 +137,17 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
110137
111138 // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
112139 // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
113- //CHECK-LABEL: load_store_matrix_4
114- gpu.func @load_store_matrix_4 (%arg0: memref <4096 xi8 , 3 >) -> vector <8 xf16 > {
140+ //CHECK-LABEL: load_store_matrix_blocked_strided_return_vector
141+ gpu.func @load_store_matrix_blocked_strided_return_vector (%arg0: memref <4096 xi8 , 3 >) -> vector <8 xf16 > {
115142 %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <stride = [1 , 32 ], block = [16 , 16 ]>>
116143
117- //CHECK: %[[c0:.*]] = arith.constant 0 : index
118144 //CHECK: %[[tid_x:.*]] = gpu.thread_id x
119-
120145 //CHECK: %[[c16:.*]] = arith.constant 16 : index
121146 //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
122147 //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
123148 //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
124149 //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
125-
150+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
126151 //CHECK: %[[c256:.*]] = arith.constant 256 : index
127152 //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
128153 //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -150,25 +175,23 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
150175
151176 // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
152177 // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
153- //CHECK-LABEL: load_store_matrix_5
154- gpu.func @load_store_matrix_5 (%arg0: memref <4096 xi8 , 3 >) -> vector <8 xf16 > {
155- //CHECK: %[[c0:.*]] = arith.constant 0 : index
156- //CHECK: %[[view :.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
157-
158- %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <block = [16 , 16 ]>>
159-
178+ //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
179+ gpu.func @load_store_matrix_blocked_subgroupblockio (%arg0: memref <4096 xi8 , 3 >) -> vector <8 xf16 > {
180+
181+ //CHECK: %[[intptr :.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
182+ //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
183+ %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <block = [16 , 16 ]>>
184+
160185 //CHECK: %[[c16:.*]] = arith.constant 16 : index
161186 //CHECK: %[[c48:.*]] = arith.constant 48 : index
162-
163187 %c16 = arith.constant 16 : index
164188 %c48 = arith.constant 48 : index
165189
166- //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
167- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
168190 //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
169191 //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
170192 //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
171193 //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
194+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
172195 //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
173196 //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
174197 //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -183,7 +206,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
183206 //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
184207 //CHECK: %[[c2:.*]] = arith.constant 2 : i32
185208 //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
186- //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64 ]], %[[byteOffset]] : i32
209+ //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32 ]], %[[byteOffset]] : i32
187210 //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
188211 //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
189212 //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
0 commit comments