@@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
44
55 // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
66 // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
7- //CHECK-LABEL: load_store_matrix_1
8- gpu.func @load_store_matrix_1 (%arg0: memref <4096 xi8 , 3 >) -> f32 {
7+ //CHECK-LABEL: load_store_matrix_plain
8+ gpu.func @load_store_matrix_plain (%arg0: memref <4096 xi8 , 3 >) -> f32 {
99 %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x32 xf32 >
1010
1111 //CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -26,10 +26,38 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
2626 gpu.return %1: f32
2727 }
2828
29+ //CHECK-LABEL: load_store_matrix_plain_2d_input
30+ gpu.func @load_store_matrix_plain (%arg0: memref <8192 xi8 , 3 >) -> f32 {
31+
32+ %view = memref.view %arg0 [0 ][]: memref <8192 xi8 , 3 > to memref <64 x32 xf32 , 3 >
33+
34+ %subview = memref.subview %view [64 , 0 ] [64 , 128 ] [1 , 1 ] : memref <64 x32 xf32 , 3 > to memref <32 x32 xf32 , strided <[32 , 1 ], offset : 1024 >, 3 >
35+
36+ %0 = xegpu.create_mem_desc %subview : memref <32 x32 xf32 , strided <[32 , 1 ], offset : 1024 >, 3 > -> !xegpu.mem_desc <32 x32 xf32 >
37+
38+ //CHECK: %[[TID:.*]] = gpu.thread_id x
39+ //CHECK: %[[C1:.*]] = arith.constant 1 : index
40+ //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
41+ //CHECK: %[[C4:.*]] = arith.constant 4 : i32
42+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
43+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
44+
45+ %tid_x = gpu.thread_id x
46+ %c0 = arith.constant 0 : index
47+ %1 = xegpu.load_matrix %0 [%c0 , %tid_x ]: !xegpu.mem_desc <32 x32 xf32 >, index , index -> f32
48+
49+ //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
50+
51+ xegpu.store_matrix %1 , %0 [%c0 , %tid_x ]: f32 , !xegpu.mem_desc <32 x32 xf32 >, index , index
52+
53+ gpu.return %1: f32
54+ }
55+
56+
2957// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
3058 // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
31- //CHECK-LABEL: load_store_matrix_2
32- gpu.func @load_store_matrix_2 (%arg0: memref <4096 xi8 , 3 >) -> f16 {
59+ //CHECK-LABEL: load_store_matrix_blocked_strided
60+ gpu.func @load_store_matrix_blocked_strided (%arg0: memref <4096 xi8 , 3 >) -> f16 {
3361 %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <stride = [1 , 32 ], block = [16 , 16 ]>>
3462 //CHECK: %[[c0:.*]] = arith.constant 0 : index
3563 //CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -68,8 +96,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
6896
6997 // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
7098 // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
71- //CHECK-LABEL: load_store_matrix_3
72- gpu.func @load_store_matrix_3 (%arg0: memref <4096 xi8 , 3 >) -> f16 {
99+ //CHECK-LABEL: load_store_matrix_blocked_nostride
100+ gpu.func @load_store_matrix_blocked_nostride (%arg0: memref <4096 xi8 , 3 >) -> f16 {
73101 //CHECK: %[[c0:.*]] = arith.constant 0 : index
74102 //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
75103 %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <block = [16 , 16 ]>>
@@ -110,8 +138,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
110138
111139 // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
112140 // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
113- //CHECK-LABEL: load_store_matrix_4
114- gpu.func @load_store_matrix_4 (%arg0: memref <4096 xi8 , 3 >) -> vector <8 xf16 > {
141+ //CHECK-LABEL: load_store_matrix_blocked_strided_return_vector
142+ gpu.func @load_store_matrix_blocked_strided_return_vector (%arg0: memref <4096 xi8 , 3 >) -> vector <8 xf16 > {
115143 %0 = xegpu.create_mem_desc %arg0 : memref <4096 xi8 , 3 > -> !xegpu.mem_desc <32 x64 xf16 , #xegpu.mem_layout <stride = [1 , 32 ], block = [16 , 16 ]>>
116144
117145 //CHECK: %[[c0:.*]] = arith.constant 0 : index
@@ -150,8 +178,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
150178
151179 // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
152180 // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
153- //CHECK-LABEL: load_store_matrix_5
154- gpu.func @load_store_matrix_5 (%arg0: memref <4096 xi8 , 3 >) -> vector <8 xf16 > {
181+ //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
182+ gpu.func @load_store_matrix_blocked_subgroupblockio (%arg0: memref <4096 xi8 , 3 >) -> vector <8 xf16 > {
155183 //CHECK: %[[c0:.*]] = arith.constant 0 : index
156184 //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
157185
0 commit comments