Skip to content

Commit e17e8fc

Browse files
authored
[e2e] Add cross lane intra subgroup softmax kernel. (#1154)
1 parent 3ca2a37 commit e17e8fc

File tree

1 file changed

+159
-0
lines changed

1 file changed

+159
-0
lines changed
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
// RUN: imex-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=workgroup" \
// RUN: | mlir-runner \
// RUN:   --shared-libs=%mlir_levelzero_runtime \
// RUN:   --shared-libs=%mlir_runner_utils \
// RUN:   --shared-libs=%mlir_c_runner_utils \
// RUN:   --shared-libs=%irunner_utils \
// RUN:   --entry-point-result=void \
// RUN: | FileCheck %s
10+
#lo_sg_8x1_data_8x64 = #xegpu.layout<sg_layout = [8, 1], sg_data = [8, 64], order = [1, 0]>
11+
#lo_sg_1x8_data_64x8 = #xegpu.layout<sg_layout = [1, 8], sg_data = [64, 8], order = [0, 1]>
12+
module attributes {gpu.container_module} {
13+
func.func @main() attributes {llvm.emit_c_interface} {
14+
%c0 = arith.constant 0 : index
15+
%c1 = arith.constant 1 : index
16+
%c16 = arith.constant 16 : index
17+
%c64 = arith.constant 64 : index
18+
%c128 = arith.constant 128 : index
19+
%c1024 = arith.constant 1024 : index
20+
%c0_f32 = arith.constant 0.0 : f32
21+
%cf_lower = arith.constant -0.5 : f32
22+
%cf_upper = arith.constant 0.5 : f32
23+
%c_gen_int = arith.constant 0 : i1
24+
25+
// Allocate and randomly initialize input in [-0.5, 0.5]
26+
%arg0 = memref.alloc() : memref<1024x64xf32>
27+
%arg0_cast = memref.cast %arg0 : memref<1024x64xf32> to memref<*xf32>
28+
call @fillResource1DRandomF32(%arg0_cast, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> ()
29+
30+
// Allocate and zero-initialize GPU output buffer (CPU side)
31+
%arg1 = memref.alloc() : memref<1024x64xf32>
32+
scf.for %i = %c0 to %c1024 step %c1 {
33+
scf.for %j = %c0 to %c64 step %c1 {
34+
memref.store %c0_f32, %arg1[%i, %j] : memref<1024x64xf32>
35+
}
36+
}
37+
38+
// Allocate GPU buffers and copy input/output to GPU
39+
%arg0_gpu = gpu.alloc () : memref<1024x64xf32>
40+
gpu.memcpy %arg0_gpu, %arg0 : memref<1024x64xf32>, memref<1024x64xf32>
41+
%arg1_gpu = gpu.alloc () : memref<1024x64xf32>
42+
gpu.memcpy %arg1_gpu, %arg1 : memref<1024x64xf32>, memref<1024x64xf32>
43+
44+
// Launch kernel and wait for completion
45+
gpu.launch_func @main_kernel::@main_kernel
46+
blocks in (%c16, %c1, %c1) threads in (%c128, %c1, %c1)
47+
args(%arg0_gpu : memref<1024x64xf32>, %arg1_gpu : memref<1024x64xf32>)
48+
gpu.wait
49+
50+
// Copy result back to host
51+
gpu.memcpy %arg1, %arg1_gpu : memref<1024x64xf32>, memref<1024x64xf32>
52+
gpu.dealloc %arg0_gpu : memref<1024x64xf32>
53+
gpu.dealloc %arg1_gpu : memref<1024x64xf32>
54+
55+
// Compute CPU reference
56+
%cpu_out = memref.alloc() : memref<1024x64xf32>
57+
scf.for %i = %c0 to %c1024 step %c1 {
58+
scf.for %j = %c0 to %c64 step %c1 {
59+
memref.store %c0_f32, %cpu_out[%i, %j] : memref<1024x64xf32>
60+
}
61+
}
62+
call @cpu_reference(%arg0, %cpu_out) : (memref<1024x64xf32>, memref<1024x64xf32>) -> ()
63+
64+
// Compare GPU and CPU results
65+
%arg1_star = memref.cast %arg1 : memref<1024x64xf32> to memref<*xf32>
66+
%cpu_out_star = memref.cast %cpu_out : memref<1024x64xf32> to memref<*xf32>
67+
// CHECK: [ALLCLOSE: TRUE]
68+
call @printAllcloseF32(%arg1_star, %cpu_out_star) : (memref<*xf32>, memref<*xf32>) -> ()
69+
// Debug print the first row of GPU and CPU output
70+
// %gpu_row_0 = memref.subview %arg1[%c0, %c0] [%c1, %c64] [%c1, %c1]
71+
// : memref<1024x64xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
72+
// %gpu_row_0_star = memref.cast %gpu_row_0
73+
// : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<*xf32>
74+
// call @printMemrefF32(%gpu_row_0_star) : (memref<*xf32>) -> ()
75+
76+
// %cpu_row_0 = memref.subview %cpu_out[%c0, %c0] [%c1, %c64] [%c1, %c1]
77+
// : memref<1024x64xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
78+
// %cpu_row_0_star = memref.cast %cpu_row_0
79+
// : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<*xf32>
80+
// call @printMemrefF32(%cpu_row_0_star) : (memref<*xf32>) -> ()
81+
82+
memref.dealloc %arg0 : memref<1024x64xf32>
83+
memref.dealloc %arg1 : memref<1024x64xf32>
84+
memref.dealloc %cpu_out : memref<1024x64xf32>
85+
return
86+
}
87+
func.func @cpu_reference(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) attributes {llvm.emit_c_interface} {
88+
%c0 = arith.constant 0 : index
89+
%c1 = arith.constant 1 : index
90+
%c1024 = arith.constant 1024 : index
91+
%c64 = arith.constant 64 : index
92+
%neg_inf = arith.constant 0xFF800000 : f32
93+
%zero = arith.constant 0.0 : f32
94+
// Iterate over each row
95+
scf.for %row = %c0 to %c1024 step %c1 {
96+
// Step 1: find row max
97+
%max = scf.for %col = %c0 to %c64 step %c1 iter_args(%cur_max = %neg_inf) -> f32 {
98+
%val = memref.load %arg0[%row, %col] : memref<1024x64xf32>
99+
%new_max = arith.maximumf %cur_max, %val : f32
100+
scf.yield %new_max : f32
101+
}
102+
// Step 2: compute exp(x - max) and accumulate sum
103+
%sum = scf.for %col = %c0 to %c64 step %c1 iter_args(%cur_sum = %zero) -> f32 {
104+
%val = memref.load %arg0[%row, %col] : memref<1024x64xf32>
105+
%shifted = arith.subf %val, %max : f32
106+
%exp_val = math.exp %shifted : f32
107+
memref.store %exp_val, %arg1[%row, %col] : memref<1024x64xf32>
108+
%new_sum = arith.addf %cur_sum, %exp_val : f32
109+
scf.yield %new_sum : f32
110+
}
111+
// Step 3: divide by sum
112+
scf.for %col = %c0 to %c64 step %c1 {
113+
%exp_val = memref.load %arg1[%row, %col] : memref<1024x64xf32>
114+
%result = arith.divf %exp_val, %sum : f32
115+
memref.store %result, %arg1[%row, %col] : memref<1024x64xf32>
116+
}
117+
}
118+
return
119+
}
120+
gpu.module @main_kernel [#xevm.target<chip = "pvc">] {
121+
gpu.func @main_kernel(%arg0: memref<1024x64xf32>, %arg1: memref<1024x64xf32>) kernel attributes {intel_reqd_sub_group_size = 16 : i32, known_block_size = array<i32: 128, 1, 1>, known_grid_size = array<i32: 16, 1, 1>} {
122+
%cst = arith.constant {layout_result_0 = #xegpu.slice<#lo_sg_8x1_data_8x64, dims = [1]>} dense<0.000000e+00> : vector<64xf32>
123+
%cst_0 = arith.constant {layout_result_0 = #xegpu.slice<#lo_sg_8x1_data_8x64, dims = [1]>} dense<0xFF800000> : vector<64xf32>
124+
%c64 = arith.constant 64 : index
125+
%block_id_x = gpu.block_id x
126+
%0 = arith.muli %block_id_x, %c64 overflow<nsw> : index
127+
%1 = xegpu.create_nd_tdesc %arg0 : memref<1024x64xf32>
128+
-> !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #lo_sg_8x1_data_8x64>
129+
%2 = xegpu.load_nd %1[%0, 0] <{layout = #lo_sg_8x1_data_8x64}> :
130+
!xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #lo_sg_8x1_data_8x64> -> vector<64x64xf32>
131+
%4 = vector.multi_reduction <maximumf>, %2, %cst_0
132+
{layout_result_0 = #xegpu.slice<#lo_sg_8x1_data_8x64, dims = [1]>} [1] : vector<64x64xf32> to vector<64xf32>
133+
%5 = xegpu.convert_layout %4 <
134+
{input_layout = #xegpu.slice<#lo_sg_8x1_data_8x64, dims = [1]>,
135+
target_layout = #xegpu.slice<#lo_sg_1x8_data_64x8, dims = [0]>}> : vector<64xf32>
136+
%6 = vector.broadcast %5 {layout_result_0 = #lo_sg_1x8_data_64x8} : vector<64xf32> to vector<64x64xf32>
137+
%7 = vector.transpose %6, [1, 0] {layout_result_0 = #lo_sg_8x1_data_8x64} : vector<64x64xf32> to vector<64x64xf32>
138+
%8 = arith.subf %2, %7 {layout_result_0 = #lo_sg_8x1_data_8x64} : vector<64x64xf32>
139+
%9 = math.exp %8 {layout_result_0 = #lo_sg_8x1_data_8x64} : vector<64x64xf32>
140+
%11 = vector.multi_reduction <add>, %9, %cst
141+
{layout_result_0 = #xegpu.slice<#lo_sg_8x1_data_8x64, dims = [1]>} [1] : vector<64x64xf32> to vector<64xf32>
142+
%12 = xegpu.convert_layout %11 <
143+
{input_layout = #xegpu.slice<#lo_sg_8x1_data_8x64, dims = [1]>,
144+
target_layout = #xegpu.slice<#lo_sg_1x8_data_64x8, dims = [0]>}> : vector<64xf32>
145+
%13 = vector.broadcast %12
146+
{layout_result_0 = #lo_sg_1x8_data_64x8} : vector<64xf32> to vector<64x64xf32>
147+
%14 = vector.transpose %13, [1, 0] {layout_result_0 = #lo_sg_8x1_data_8x64} : vector<64x64xf32> to vector<64x64xf32>
148+
%15 = arith.divf %9, %14 {layout_result_0 = #lo_sg_8x1_data_8x64} : vector<64x64xf32>
149+
%16 = xegpu.create_nd_tdesc %arg1 : memref<1024x64xf32> ->
150+
!xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #lo_sg_8x1_data_8x64>
151+
xegpu.store_nd %15, %16[%0, 0] <{layout = #lo_sg_8x1_data_8x64}>
152+
: vector<64x64xf32>, !xegpu.tensor_desc<64x64xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #lo_sg_8x1_data_8x64>
153+
gpu.return
154+
}
155+
}
156+
func.func private @fillResource1DRandomF32(memref<*xf32>, f32, f32, i1) attributes {llvm.emit_c_interface}
157+
func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
158+
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
159+
}

0 commit comments

Comments
 (0)