Skip to content

Commit fba6cc5

Browse files
authored
[CI][Test][Reduction][Xe2/3] Add 1D cross subgroup cross lane (#1157)
1 parent 4c89db1 commit fba6cc5

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// RUN: imex-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=workgroup zebin-chip=pvc" \
2+
// RUN: | mlir-runner \
3+
// RUN: --shared-libs=%mlir_levelzero_runtime \
4+
// RUN: --shared-libs=%mlir_runner_utils \
5+
// RUN: --shared-libs=%mlir_c_runner_utils \
6+
// RUN: --shared-libs=%irunner_utils \
7+
// RUN: --entry-point-result=void \
8+
// RUN: | FileCheck %s
9+
10+
#data_layout = #xegpu.layout<sg_layout = [4, 16], sg_data = [2, 16], inst_data = [1, 16], lane_layout= [1, 16], lane_data=[1, 1]>
11+
module attributes {gpu.container_module} {
12+
gpu.module @reduction {
13+
gpu.func @cross_sg_cross_lane_1D(%dst: memref<128xf32, 1>, %src: memref<8x256xf32, 1>) kernel {
14+
%dst1 = memref.memory_space_cast %dst : memref<128xf32, 1> to memref<128xf32>
15+
%dst_ptr_idx = memref.extract_aligned_pointer_as_index %dst1 : memref<128xf32> -> index
16+
%dst_ptr_i64 = arith.index_cast %dst_ptr_idx : index to i64
17+
%src1 = memref.memory_space_cast %src : memref<8x256xf32, 1> to memref<8x256xf32>
18+
%src_ptr_idx = memref.extract_aligned_pointer_as_index %src1 : memref<8x256xf32> -> index
19+
%src_ptr_i64 = arith.index_cast %src_ptr_idx : index to i64
20+
21+
%c256 = arith.constant dense<256> : vector<8xindex>
22+
%offset_i0 = vector.step : vector<8xindex>
23+
%offset_i = arith.muli %offset_i0, %c256 : vector<8xindex>
24+
%offset_j = vector.step : vector<256xindex>
25+
%offset_i_bcast = vector.broadcast %offset_i: vector<8xindex> to vector<256x8xindex>
26+
%offset_i_bcast_t = vector.transpose %offset_i_bcast, [1, 0]: vector<256x8xindex> to vector<8x256xindex>
27+
%offset_j_bcast = vector.broadcast %offset_j : vector<256xindex> to vector<8x256xindex>
28+
%offset_ld = arith.addi %offset_i_bcast_t, %offset_j_bcast : vector<8x256xindex>
29+
30+
%mask_ld = arith.constant dense<1> : vector<8x256xi1>
31+
%val = xegpu.load %src_ptr_i64[%offset_ld], %mask_ld : i64, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
32+
%acc = arith.constant dense<0.0> : vector<8xf32>
33+
%res = vector.multi_reduction <add>, %val, %acc [1] : vector<8x256xf32> to vector<8xf32>
34+
35+
%offset = vector.step : vector<8xindex>
36+
%mask = arith.constant dense<1> : vector<8xi1>
37+
xegpu.store %res, %dst_ptr_i64[%offset], %mask { layout = #xegpu.slice<#data_layout, dims = [1]> } : vector<8xf32>, i64, vector<8xindex>, vector<8xi1>
38+
gpu.return
39+
}
40+
}
41+
42+
func.func @test(%dst : memref<128xf32>, %src : memref<8x256xf32>) attributes {llvm.emit_c_interface} {
43+
%c1 = arith.constant 1 : index
44+
%c16 = arith.constant 16 : index
45+
46+
%c32 = arith.constant 32 : index
47+
%c64 = arith.constant 64 : index
48+
%c128 = arith.constant 128 : index
49+
50+
%c1024 = arith.constant 1024 : index // 4 * 16 * 16
51+
52+
%c2 = arith.constant 2 : index
53+
%c4 = arith.constant 4 : index
54+
55+
%stream0_0 = gpu.wait async
56+
57+
%gpu_memref_dst, %stream0_1 = gpu.alloc async [%stream0_0] () : memref<128xf32>
58+
%stream0_2 = gpu.memcpy async [%stream0_1] %gpu_memref_dst, %dst : memref<128xf32>, memref<128xf32>
59+
60+
%gpu_memref_src, %stream0_3 = gpu.alloc async [%stream0_2] () : memref<8x256xf32>
61+
%stream0_4 = gpu.memcpy async [%stream0_3] %gpu_memref_src, %src : memref<8x256xf32>, memref<8x256xf32>
62+
63+
64+
%dst_ptr_idx = memref.extract_aligned_pointer_as_index %gpu_memref_dst : memref<128xf32> -> index
65+
%dst_ptr_i64 = arith.index_cast %dst_ptr_idx : index to i64
66+
67+
%src_ptr_idx = memref.extract_aligned_pointer_as_index %gpu_memref_src : memref<8x256xf32> -> index
68+
%src_ptr_i64 = arith.index_cast %src_ptr_idx : index to i64
69+
70+
%gpu_memref_dst_casted = memref.memory_space_cast %gpu_memref_dst : memref<128xf32> to memref<128xf32, 1>
71+
%gpu_memref_src_casted = memref.memory_space_cast %gpu_memref_src : memref<8x256xf32> to memref<8x256xf32, 1>
72+
73+
%stream0_5 = gpu.launch_func async[%stream0_4] @reduction::@cross_sg_cross_lane_1D blocks in (%c1, %c1, %c1) threads in (%c1024, %c1, %c1) args(%gpu_memref_dst_casted : memref<128xf32, 1>, %gpu_memref_src_casted : memref<8x256xf32, 1>)
74+
75+
%stream0_6 = gpu.memcpy async [%stream0_5] %dst, %gpu_memref_dst : memref<128xf32>, memref<128xf32>
76+
%stream0_8 = gpu.dealloc async [%stream0_6] %gpu_memref_dst : memref<128xf32>
77+
gpu.wait [%stream0_8]
78+
return
79+
}
80+
81+
func.func @main() attributes {llvm.emit_c_interface} {
82+
%dst = memref.alloc() : memref<128xf32>
83+
%src = memref.alloc() : memref<8x256xf32>
84+
85+
%c0 = arith.constant 0 : index
86+
%c1 = arith.constant 1 : index
87+
%c8 = arith.constant 8 : index
88+
%c256 = arith.constant 256 : index
89+
%c128 = arith.constant 128 : index
90+
91+
%c0_f32 = arith.constant 0. : f32
92+
scf.for %i = %c0 to %c8 step %c1 {
93+
scf.for %j = %c0 to %c256 step %c1 {
94+
%i_f32 = arith.index_cast %i : index to i32
95+
%j_f32 = arith.index_cast %j : index to i32
96+
%i_float = arith.sitofp %i_f32 : i32 to f32
97+
%j_float = arith.sitofp %j_f32 : i32 to f32
98+
%c1000_f32 = arith.constant 1000.0 : f32
99+
%j_scaled = arith.divf %j_float, %c1000_f32 : f32
100+
%val = arith.addf %i_float, %j_scaled : f32
101+
// Input is in format (#row_idx).(#col_idx/1000.0)
102+
memref.store %i_float, %src[%i, %j] : memref<8x256xf32>
103+
}
104+
}
105+
%c0_i64 = arith.constant 0 : i64
106+
107+
scf.for %i = %c0 to %c128 step %c1 {
108+
memref.store %c0_f32, %dst[%i] : memref<128xf32>
109+
}
110+
call @test(%dst, %src) : (memref<128xf32>, memref<8x256xf32>) -> ()
111+
%dst_cast = memref.cast %dst : memref<128xf32> to memref<*xf32>
112+
%src_cast = memref.cast %src : memref<8x256xf32> to memref<*xf32>
113+
114+
// CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
115+
// CHECK-NEXT: [0, 256, 512, 768, 1024, 1280, 1536, 1792, 0,
116+
call @printMemrefF32(%dst_cast) : (memref<*xf32>) -> ()
117+
return
118+
}
119+
func.func private @printMemrefF32(%ptr : memref<*xf32>) attributes { llvm.emit_c_interface }
120+
}

0 commit comments

Comments
 (0)