Skip to content

Commit 061b6e0

Browse files
committed
add 1d and 2d elemwise test
1 parent aa4ba9c commit 061b6e0

File tree

1 file changed

+93
-11
lines changed

1 file changed

+93
-11
lines changed

mlir/test/Dialect/XeGPU/xegpu-blocking.mlir

Lines changed: 93 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
// RUN: mlir-opt --xegpu-blocking -split-input-file %s | FileCheck %s
22

3-
43
#a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
54
#b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>
65
#c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
7-
8-
#l1 = #xegpu.layout<inst_data = [8, 16]>
9-
#l2 = #xegpu.layout<inst_data = [16, 16]>
10-
116
gpu.module @test_kernel {
127
gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
138
%c0 = arith.constant 0 : index
@@ -44,9 +39,13 @@ gpu.module @test_kernel {
4439
xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c>
4540
gpu.return
4641
}
42+
}
4743

48-
//-----
49-
gpu.func @test_gemm_simple(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
44+
// -----
45+
#l1 = #xegpu.layout<inst_data = [8, 16]>
46+
#l2 = #xegpu.layout<inst_data = [16, 16]>
47+
gpu.module @test_kernel {
48+
gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
5049
%c0 = arith.constant 0 : index
5150
%c16 = arith.constant 16 : index
5251
%c32 = arith.constant 32 : index
@@ -81,10 +80,14 @@ gpu.module @test_kernel {
8180
xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #l1>
8281
gpu.return
8382
}
83+
}
8484

85-
//-----
86-
87-
gpu.func @test_gemm_a_preop(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
85+
// -----
86+
#a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
87+
#b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>
88+
#c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
89+
gpu.module @test_kernel {
90+
gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
8891
%c0 = arith.constant 0 : index
8992
%c16 = arith.constant 16 : index
9093
%c32 = arith.constant 32 : index
@@ -120,4 +123,83 @@ gpu.module @test_kernel {
120123
//CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
121124
xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c>
122125
gpu.return
123-
}}
126+
}
127+
}
128+
129+
// -----
130+
#l = #xegpu.layout<inst_data = [8, 16]>
131+
gpu.module @test_kernel {
132+
gpu.func @test_elementwise(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
133+
%c0 = arith.constant 0 : index
134+
%c32 = arith.constant 32 : index
135+
%c1024 = arith.constant 1024 : index
136+
%block_id_x = gpu.block_id x
137+
%block_id_y = gpu.block_id y
138+
%m = arith.muli %block_id_x, %c32 : index
139+
140+
%a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
141+
%b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
142+
%c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
143+
144+
%out:3 = scf.for %k = %c0 to %c1024 step %c32
145+
iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
146+
-> (!xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>) {
147+
//CHECK-COUNT-8: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
148+
%a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #l> -> vector<16x32xf16>
149+
%b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16, #l> -> vector<16x32xf16>
150+
151+
//CHECK-COUNT-4: arith.addf {{.*}} : vector<8x16xf16>
152+
%c = arith.addf %a, %b {layout_result_0 = #l} : vector<16x32xf16>
153+
154+
//CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
155+
xegpu.store_nd %c, %arg2: vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #l>
156+
157+
//CHECK-COUNT-12: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16>
158+
%a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
159+
%b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
160+
%c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
161+
scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
162+
: !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>
163+
}
164+
gpu.return
165+
}
166+
}
167+
168+
// -----
169+
#l = #xegpu.layout<inst_data = [8]>
170+
gpu.module @test_kernel {
171+
gpu.func @test_elementwise(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
172+
%c0 = arith.constant 0 : index
173+
%c32 = arith.constant 32 : index
174+
%c1024 = arith.constant 1024 : index
175+
%block_id_x = gpu.block_id x
176+
%block_id_y = gpu.block_id y
177+
%m = arith.muli %block_id_x, %c32 : index
178+
179+
%a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
180+
%b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
181+
%c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
182+
183+
%out:3 = scf.for %k = %c0 to %c1024 step %c32
184+
iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
185+
-> (!xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>) {
186+
//CHECK-COUNT-8: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8xf16> -> vector<8xf16>
187+
%a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<32xf16, #l> -> vector<32xf16>
188+
%b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32xf16, #l> -> vector<32xf16>
189+
190+
//CHECK-COUNT-4: arith.addf {{.*}} : vector<8xf16>
191+
%c = arith.addf %a, %b {layout_result_0 = #l} : vector<32xf16>
192+
193+
//CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8xf16>, !xegpu.tensor_desc<8xf16>
194+
xegpu.store_nd %c, %arg2: vector<32xf16>, !xegpu.tensor_desc<32xf16, #l>
195+
196+
//CHECK-COUNT-12: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8xf16>
197+
%a_next_tdesc = xegpu.update_nd_offset %arg0, [%c32] : !xegpu.tensor_desc<32xf16, #l>
198+
%b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32] : !xegpu.tensor_desc<32xf16, #l>
199+
%c_next_tdesc = xegpu.update_nd_offset %arg2, [%c32] : !xegpu.tensor_desc<32xf16, #l>
200+
scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
201+
: !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>
202+
}
203+
gpu.return
204+
}
205+
}

0 commit comments

Comments
 (0)