Skip to content

Commit bbd43d0

Browse files
committed
polish tests
1 parent c89c5db commit bbd43d0

File tree

2 files changed

+174
-9
lines changed

2 files changed

+174
-9
lines changed

mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Lines changed: 145 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,200 @@
1-
// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s
22

33
gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
44

5-
// e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
5+
// e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
66
// its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
77
//CHECK-LABEL: load_store_matrix_1
88
gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
99
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
10+
11+
//CHECK: %[[TID:.*]] = gpu.thread_id x
12+
//CHECK: %[[C1:.*]] = arith.constant 1 : index
13+
//CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
14+
//CHECK: %[[C4:.*]] = arith.constant 4 : i64
15+
//CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i64
1016
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
17+
1118
%tid_x = gpu.thread_id x
1219
%c0 = arith.constant 0 : index
1320
%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
21+
22+
//CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
23+
24+
xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
25+
1426
gpu.return %1: f32
1527
}
1628

17-
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
29+
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
1830
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
1931
//CHECK-LABEL: load_store_matrix_2
2032
gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
2133
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
22-
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
34+
//CHECK: %[[c0:.*]] = arith.constant 0 : index
35+
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
36+
//CHECK: %[[c13:.*]] = arith.constant 13 : index
37+
//CHECK: %[[c16:.*]] = arith.constant 16 : index
38+
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
39+
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
40+
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
41+
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
42+
43+
//CHECK: %[[c256:.*]] = arith.constant 256 : index
44+
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
45+
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
46+
//CHECK: %[[c512:.*]] = arith.constant 512 : index
47+
//CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
48+
//CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
49+
//CHECK: %[[c1:.*]] = arith.constant 1 : index
50+
//CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
51+
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
52+
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
53+
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
54+
55+
//CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
56+
57+
2358
%tid_x = gpu.thread_id x
2459
%c13 = arith.constant 13 : index
2560
%1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
61+
62+
//CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
63+
64+
xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
2665
gpu.return %1: f16
2766
}
2867

68+
2969
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
3070
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
3171
//CHECK-LABEL: load_store_matrix_3
3272
gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
73+
//CHECK: %[[c0:.*]] = arith.constant 0 : index
74+
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
3375
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
34-
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
76+
77+
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
78+
//CHECK: %[[c19:.*]] = arith.constant 19 : index
3579
%tid_x = gpu.thread_id x
3680
%c19 = arith.constant 19: index
81+
82+
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
83+
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
84+
//CHECK: %[[c16:.*]] = arith.constant 16 : index
85+
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
86+
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
87+
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
88+
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
89+
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
90+
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
91+
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
92+
//CHECK: %[[c256:.*]] = arith.constant 256 : index
93+
//CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
94+
//CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
95+
//CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
96+
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
97+
//CHECK: %[[c1:.*]] = arith.constant 1 : index
98+
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
99+
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
100+
101+
//CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
37102
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
103+
104+
//CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
105+
xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
106+
107+
//CHECK: gpu.return %[[loaded]] : f16
38108
gpu.return %1: f16
39109
}
40-
41-
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
110+
111+
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
42112
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
43113
//CHECK-LABEL: load_store_matrix_4
44114
gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
45115
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
46-
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
116+
117+
//CHECK: %[[c0:.*]] = arith.constant 0 : index
118+
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
119+
120+
//CHECK: %[[c16:.*]] = arith.constant 16 : index
121+
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
122+
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
123+
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
124+
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
125+
126+
//CHECK: %[[c256:.*]] = arith.constant 256 : index
127+
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
128+
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
129+
//CHECK: %[[c512:.*]] = arith.constant 512 : index
130+
//CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
131+
//CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
132+
//CHECK: %[[c1:.*]] = arith.constant 1 : index
133+
//CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
134+
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
135+
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
136+
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
137+
138+
//CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
139+
47140
%tid_x = gpu.thread_id x
48141
%c16 = arith.constant 16 : index
49142
%1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
143+
144+
//CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3>
145+
xegpu.store_matrix %1, %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
146+
50147
gpu.return %1: vector<8xf16>
51148
}
52149

150+
53151
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
54152
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
55153
//CHECK-LABEL: load_store_matrix_5
56154
gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
155+
//CHECK: %[[c0:.*]] = arith.constant 0 : index
156+
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
157+
57158
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
58-
//CHECK: xevm.blockload {{.*}} : (!llvm.ptr<3>) -> vector<8xi16>
159+
160+
//CHECK: %[[c16:.*]] = arith.constant 16 : index
161+
//CHECK: %[[c48:.*]] = arith.constant 48 : index
162+
59163
%c16 = arith.constant 16 : index
60164
%c48 = arith.constant 48 : index
165+
166+
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
167+
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
168+
//CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
169+
//CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
170+
//CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
171+
//CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
172+
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
173+
//CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
174+
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
175+
//CHECK: %[[c256:.*]] = arith.constant 256 : index
176+
//CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
177+
//CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
178+
//CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
179+
//CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
180+
//CHECK: %[[c1:.*]] = arith.constant 1 : index
181+
//CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
182+
//CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
183+
//CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i64
184+
//CHECK: %[[c2:.*]] = arith.constant 2 : i64
185+
//CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i64
186+
//CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i64
187+
//CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i64 to !llvm.ptr<3>
188+
//CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
189+
//CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
190+
61191
%1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
192+
193+
//CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
194+
//CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
195+
196+
xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
197+
62198
gpu.return %1: vector<8xf16>
63199
}
64200

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,21 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
870870
return
871871
}
872872

873+
// -----
874+
func.func @load_mem_desc_invalid_attr1(%arg0: !xegpu.mem_desc<16x64xf16>) {
875+
// expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
876+
%data1 = xegpu.load_matrix %arg0[8, 8]<{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
877+
return
878+
}
879+
880+
// -----
881+
func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
882+
// expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
883+
%data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
884+
return
885+
}
886+
887+
873888
// -----
874889
func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
875890
// expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}}
@@ -891,6 +906,20 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
891906
return
892907
}
893908

909+
// -----
910+
func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
911+
// expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
912+
xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
913+
return
914+
}
915+
916+
// -----
917+
func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
918+
// expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
919+
xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
920+
return
921+
}
922+
894923
// -----
895924
func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
896925
// expected-error@+1 {{result shape must not exceed source shape}}

0 commit comments

Comments
 (0)