
Commit 8c725c9

[XeGPUToXeVM] Add SCF lowering support with simple e2e GEMM test case. (#1074)
[XeGPUToXeVM] Add GEMM lowering support with simple e2e test case.
1 parent 227f725 commit 8c725c9


3 files changed: +186 −0 lines changed


lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
@@ -677,6 +678,8 @@ struct ConvertXeGPUToXeVMPass

     RewritePatternSet patterns(&getContext());
     imex::populateXeGPUToXeVMConversionPatterns(patterns, typeConverter);
+    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+        typeConverter, patterns, target);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
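
The two added lines matter because the new test's GEMM kernel carries its accumulator through scf.for iter_args: once the XeGPU type converter rewrites the types flowing through that loop, the scf.for and scf.yield ops themselves must be converted structurally as well. Below is a minimal sketch (not the IMEX pass source) of how MLIR's structural SCF conversion is typically wired into a conversion pass; ExampleConversionPass and its identity type converter are placeholders.

// Sketch only: placeholder pass showing where the structural SCF
// conversion hooks in; the real pass installs XeGPU-specific type
// conversions and rewrite patterns instead of the identity converter.
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace {
struct ExampleConversionPass
    : mlir::PassWrapper<ExampleConversionPass,
                        mlir::OperationPass<mlir::ModuleOp>> {
  void runOnOperation() override {
    mlir::MLIRContext &ctx = getContext();

    // Placeholder converter; a real pass maps dialect types (e.g.
    // !xegpu.tensor_desc) to their lowered equivalents here.
    mlir::TypeConverter typeConverter;
    typeConverter.addConversion([](mlir::Type t) { return t; });

    mlir::ConversionTarget target(ctx);
    mlir::RewritePatternSet patterns(&ctx);

    // Adds patterns that rebuild scf.for / scf.if / scf.while (and their
    // yields) with converted operand, result, and iter_arg types, and marks
    // them legal only once all of those types are legal under the converter.
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        typeConverter, patterns, target);

    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                  std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace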
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
+// RUN: %python_executable %imex_runner --requires=l0-runtime,spirv-backend -i %s --pass-pipeline-file=%p/xegpu-to-llvm.pp \
+// RUN: --runner imex-cpu-runner -e main \
+// RUN: --entry-point-result=void \
+// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
+
+
+module @gemm attributes {gpu.container_module} {
+  gpu.module @kernel {
+    gpu.func @load_store_2d_dpas(%a: memref<256x256xf16>, %b: memref<256x256xf16>, %c: memref<256x256xf32>) kernel {
+      %c0 = arith.constant 0 : index
+      %c1 = arith.constant 1 : index
+      %c8 = arith.constant 8 : index
+      %c16 = arith.constant 16 : index
+      %c32 = arith.constant 32 : index
+      %c256 = arith.constant 256 : index
+      %block_x = gpu.block_id x
+      %block_y = gpu.block_id y
+      %x_block_offset = arith.muli %block_x, %c8 : index
+      %y_block_offset = arith.muli %block_y, %c16 : index
+
+      %c_tdesc = xegpu.create_nd_tdesc %c[%x_block_offset, %y_block_offset] : memref<256x256xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+      %c_init_value = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>> -> vector<8xf32>
+
+      %r = scf.for %k = %c0 to %c256 step %c16 iter_args(%arg_c = %c_init_value) -> (vector<8xf32>) {
+        // TODO: There is an issue with update_nd_offset. To avoid it, we use create_nd_tdesc here.
+        %a_tdesc_new = xegpu.create_nd_tdesc %a[%x_block_offset, %k] : memref<256x256xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_space = global>>
+        %b_tdesc_new = xegpu.create_nd_tdesc %b[%k, %y_block_offset] : memref<256x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global>>
+        %a_val = xegpu.load_nd %a_tdesc_new : !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<memory_space = global>> -> vector<8xf16>
+        %b_val = xegpu.load_nd %b_tdesc_new : !xegpu.tensor_desc<16x16xf16, #xegpu.block_tdesc_attr<memory_space = global>> -> vector<16xf16>
+        %dpas = xegpu.dpas %a_val, %b_val, %arg_c : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+        scf.yield %dpas : vector<8xf32>
+      }
+      xegpu.store_nd %r, %c_tdesc <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+      gpu.return
+    }
+  }
+
+  func.func @test(%a : memref<256x256xf16>, %b : memref<256x256xf16>, %c : memref<256x256xf32>) -> memref<256x256xf32> attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %memref_a = gpu.alloc host_shared () : memref<256x256xf16>
+    memref.copy %a, %memref_a : memref<256x256xf16> to memref<256x256xf16>
+    %memref_b = gpu.alloc host_shared () : memref<256x256xf16>
+    memref.copy %b, %memref_b : memref<256x256xf16> to memref<256x256xf16>
+    %memref_c = gpu.alloc host_shared () : memref<256x256xf32>
+    memref.copy %c, %memref_c : memref<256x256xf32> to memref<256x256xf32>
+
+    gpu.launch_func @kernel::@load_store_2d_dpas blocks in (%c32, %c16, %c1) threads in (%c16, %c1, %c1) args(%memref_a : memref<256x256xf16>, %memref_b : memref<256x256xf16>, %memref_c : memref<256x256xf32>)
+    return %memref_c : memref<256x256xf32>
+  }
+
+  // compute CPU reference (takes minutes)
+  func.func @cpu_reference(%A : memref<256x256xf16>, %B : memref<256x256xf16>, %C : memref<256x256xf32>) {
+    %c256 = arith.constant 256 : index
+    %c16 = arith.constant 16 : index
+    %c1 = arith.constant 1 : index
+    %c0 = arith.constant 0 : index
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        %c_curr = memref.load %C[%i, %j] : memref<256x256xf32>
+        %c_val = scf.for %k_tile = %c0 to %c256 step %c16 iter_args(%c_partial = %c_curr) -> f32 {
+          %c_val_dpas = scf.for %k = %c0 to %c16 step %c1 iter_args(%c_dpas_partial = %c_partial) -> f32 {
+            %k_dpas = arith.addi %k_tile, %k : index
+            %a_val = memref.load %A[%i, %k_dpas] : memref<256x256xf16>
+            %b_val = memref.load %B[%k_dpas, %j] : memref<256x256xf16>
+            %a_cast = arith.extf %a_val : f16 to f32
+            %b_cast = arith.extf %b_val : f16 to f32
+            %t = arith.mulf %a_cast, %b_cast : f32
+            // %t_cast = arith.extf %t : f16 to f16
+            %c_sum = arith.addf %t, %c_dpas_partial : f32
+            scf.yield %c_sum : f32
+          }
+          scf.yield %c_val_dpas : f32
+        }
+        // %c_val_f16 = arith.truncf %c_val : f32 to f16
+        // %c_val_ = arith.extf %c_val_f16 : f16 to f32
+        memref.store %c_val, %C[%i, %j] : memref<256x256xf32>
+      }
+    }
+    return
+  }
+
+
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c1_f16 = arith.constant 1.0 : f16
+    %c2_f16 = arith.constant 2.0 : f16
+    %c256 = arith.constant 256 : index
+    %cf_0 = arith.constant 0.0 : f16
+    %cf_1 = arith.constant 1.0 : f16
+    %A = memref.alloc() : memref<256x256xf16>
+    %B = memref.alloc() : memref<256x256xf16>
+    %C = memref.alloc() : memref<256x256xf32>
+    %C_ref = memref.alloc() : memref<256x256xf32>
+    %c_gen_int = arith.constant 0 : i1
+    %cf_lower = arith.constant -0.5 : f32
+    %cf_upper = arith.constant 0.5 : f32
+    // Use one of the two options below to initialize the A matrix
+    // Option 1: initialize matrix A ; A[i, j] = j
+    // scf.for %i = %c0 to %c256 step %c1 {
+    //   scf.for %j = %c0 to %c256 step %c1 {
+    //     %t = index.castu %j : index to i16
+    //     %val = arith.uitofp %t : i16 to f16
+    //     memref.store %val, %A[%i, %j] : memref<256x256xf16>
+    //     // memref.store %c1_f16, %A[%i, %j] : memref<256x256xf16>
+    //     // memref.store %c2_f16, %B[%i, %j] : memref<256x256xf16>
+    //   }
+    // }
+    // Option 2: convert the memref to 1D and fill with random values in (-0.5, 0.5)
+    %A_random = memref.cast %A : memref<256x256xf16> to memref<*xf16>
+    call @fillResource1DRandomF16(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
+
+
+    // Use one of the two options below to initialize the B matrix
+    // Option 1: make matrix B an identity matrix
+    // scf.for %i = %c0 to %c256 step %c1 {
+    //   scf.for %j = %c0 to %c256 step %c1 {
+    //     %i_i32 = index.castu %i : index to i32
+    //     %j_i32 = index.castu %j : index to i32
+    //     %i_j_same = arith.cmpi eq, %i_i32, %j_i32 : i32
+
+    //     scf.if %i_j_same {
+    //       memref.store %cf_1, %B[%i, %j] : memref<256x256xf16>
+    //     } else {
+    //       memref.store %cf_0, %B[%i, %j] : memref<256x256xf16>
+    //     }
+    //   }
+    // }
+    // Option 2: convert the memref to 1D and fill with random values in (-0.5, 0.5)
+    %B_random = memref.cast %B : memref<256x256xf16> to memref<*xf16>
+    call @fillResource1DRandomF16(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf16>, f32, f32, i1) -> ()
+
+
+    // initialize matrix C and C_ref ; C[i, j] = 0
+    %c0_f32 = arith.constant 0.0 : f32
+    scf.for %i = %c0 to %c256 step %c1 {
+      scf.for %j = %c0 to %c256 step %c1 {
+        memref.store %c0_f32, %C[%i, %j] : memref<256x256xf32>
+        memref.store %c0_f32, %C_ref[%i, %j] : memref<256x256xf32>
+      }
+    }
+    // print input for debugging
+    // %A_row_0 = memref.subview %A[1, 0][1, 256][1, 1] : memref<256x256xf16> to memref<1x256xf16, strided<[256, 1], offset: 256>>
+    // %A_row_0_cast = memref.cast %A_row_0 : memref<1x256xf16, strided<[256, 1], offset: 256>> to memref<*xf16>
+    // call @printMemrefF16(%A_row_0_cast) : (memref<*xf16>) -> ()
+
+    // run GPU
+    %2 = call @test(%A, %B, %C) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> memref<256x256xf32>
+
+    call @cpu_reference(%A, %B, %C_ref) : (memref<256x256xf16>, memref<256x256xf16>, memref<256x256xf32>) -> ()
+
+    // %cast = memref.cast %A : memref<256x256xf16> to memref<*xf16>
+    // call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
+    %cast_C = memref.cast %2 : memref<256x256xf32> to memref<*xf32>
+    %cast_C_ref = memref.cast %C_ref : memref<256x256xf32> to memref<*xf32>
+    // call @printMemrefF32(%cast_C) : (memref<*xf32>) -> ()
+    // call @printMemrefF32(%cast_C_ref) : (memref<*xf32>) -> ()
+
+    %C_row_0 = memref.subview %C_ref[0, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32, strided<[256, 1], offset: 0>>
+    %C_row_0_cast = memref.cast %C_row_0 : memref<1x256xf32, strided<[256, 1], offset: 0>> to memref<*xf32>
+    // call @printMemrefF32(%C_row_0_cast) : (memref<*xf32>) -> ()
+
+    %C_row_0_gpu = memref.subview %2[0, 0][1, 256][1, 1] : memref<256x256xf32> to memref<1x256xf32, strided<[256, 1], offset: 0>>
+    %C_row_0_cast_gpu = memref.cast %C_row_0_gpu : memref<1x256xf32, strided<[256, 1], offset: 0>> to memref<*xf32>
+    // call @printMemrefF32(%C_row_0_cast_gpu) : (memref<*xf32>) -> ()
+
+    // CHECK: [ALLCLOSE: TRUE]
+    call @printAllcloseF32(%cast_C, %cast_C_ref) : (memref<*xf32>, memref<*xf32>) -> ()
+    memref.dealloc %A : memref<256x256xf16>
+    memref.dealloc %B : memref<256x256xf16>
+    memref.dealloc %C : memref<256x256xf32>
+    memref.dealloc %C_ref : memref<256x256xf32>
+    return
+  }
+  func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface}
+  func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
+  func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface}
+  func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
+  func.func private @fillResource1DRandomF16(memref<*xf16>, f32, f32, i1) attributes {llvm.emit_c_interface}
+}
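
For orientation, the launch geometry in @test follows directly from the tile sizes in the kernel: each workgroup computes one 8x16 tile of C (row offset = block_id.x * 8, column offset = block_id.y * 16), so covering the 256x256 output takes 256/8 = 32 blocks in x and 256/16 = 16 in y, which is exactly `blocks in (%c32, %c16, %c1)`. The K loop runs 256/16 = 16 iterations, feeding one 8x16 A tile and one 16x16 B tile to xegpu.dpas per step. The 16 threads per block presumably correspond to the subgroup lanes over which those tiles are distributed, which would explain the per-lane vector<8xf16>/vector<16xf16>/vector<8xf32> operand types; that distribution comes from the XeGPU SIMT lowering rather than anything stated in the test itself.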

test/Integration/Dialect/XeGPUToXeVM/xegpu-to-llvm.pp

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
 convert-cf-to-llvm
 convert-vector-to-llvm
 convert-arith-to-llvm
+expand-strided-metadata
 finalize-memref-to-llvm
 gpu-to-llvm
 reconcile-unrealized-casts
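
The extra pass is presumably needed because the new test builds memref.subview ops (for the row-slice debug prints): finalize-memref-to-llvm does not lower subview on its own, while expand-strided-metadata rewrites it into extract_strided_metadata / reinterpret_cast form, so it has to run earlier in the pipeline. Below is a small sketch of the same ordering expressed through a PassManager, assuming the standard upstream pass-constructor entry points and headers; addMemRefLoweringTail is just an illustrative helper name.

// Sketch only: the two passes in the order the .pp file now uses.
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

void addMemRefLoweringTail(mlir::PassManager &pm) {
  // Rewrite memref.subview (and similar strided-metadata ops) before the
  // MemRef -> LLVM finalization runs.
  pm.addPass(mlir::memref::createExpandStridedMetadataPass());
  pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass());
}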
