Skip to content

Commit 3d16b5b

Browse files
committed
[Reduction] Verify different sizes and types work for reduction.
Patch adds tests to make sure all the tests with `vector.multi_reduction` generate successfully pass Peano legalizer and generate efficient vectorized code. This patch checks only the IREE side to keep the dependency minimun on Peano. (Depends on Peano: 1. Xilinx/llvm-aie#548 2. Xilinx/llvm-aie#557 ) 1. `reassociateFpReductions=true` is must else code is scalarized. This flag could be added into the IREE vectorization pipeline to trigger automatically. 2. bf16/i32/f32 all types with different sizes work now.
1 parent 24dd68d commit 3d16b5b

File tree

2 files changed

+178
-0
lines changed

2 files changed

+178
-0
lines changed

compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ iree_lit_test_suite(
7373
"lowering_strategy_objectfifo_npu4.mlir"
7474
"lowering_strategy_softmax.mlir"
7575
"map_forall_to_cores.mlir"
76+
"multi_reduction_to_llvm_intrinsics.mlir"
7677
"multi_reduction_to_reduction.mlir"
7778
"none_access_to_temporary_buffer.mlir"
7879
"normalize_loop_bounds.mlir"
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-amdaie-vectorization), iree-convert-to-llvm{reassociateFpReductions=true})' %s | mlir-translate --mlir-to-llvmir | FileCheck %s
2+
3+
// Tests for `vector.multi_reduction` operation and it's lowering with
4+
// different sizes and data types. Tries to check if the corresponding llvm
5+
// reduction function is generated. `reassoc` must be present for f32 and bf16
6+
// for code to vectorize in Peano
7+
8+
////////////////////////1D//////////////////////
9+
10+
// = 512 bits
11+
// CHECK-LABEL: @multi_reduction_1d_16_i32
12+
// CHECK: @llvm.vector.reduce.add.v16i32(<16 x i32> %{{.*}})
13+
func.func private @multi_reduction_1d_16_i32(%v : vector<16xi32>, %acc: i32) -> i32 {
14+
%0 = vector.multi_reduction <add>, %v, %acc[0] : vector<16xi32> to i32
15+
return %0 : i32
16+
}
17+
// < 512 bits
18+
// CHECK-LABEL: @multi_reduction_1d_8_i32
19+
// CHECK: @llvm.vector.reduce.add.v8i32(<8 x i32> %{{.*}})
20+
func.func private @multi_reduction_1d_8_i32(%v : vector<8xi32>, %acc: i32) -> i32 {
21+
%0 = vector.multi_reduction <add>, %v, %acc[0] : vector<8xi32> to i32
22+
return %0 : i32
23+
}
24+
// > 512 bits
25+
// CHECK-LABEL: @multi_reduction_1d_64_i32
26+
// CHECK: @llvm.vector.reduce.add.v64i32(<64 x i32> %{{.*}})
27+
func.func private @multi_reduction_1d_64_i32(%v : vector<64xi32>, %acc: i32) -> i32 {
28+
%0 = vector.multi_reduction <add>, %v, %acc[0] : vector<64xi32> to i32
29+
return %0 : i32
30+
}
31+
32+
// 256
33+
// CHECK-LABEL: @multi_reduction_1d_16_bf16
34+
// CHECK: call reassoc bfloat @llvm.vector.reduce.fadd.v16bf16(bfloat %{{.*}}, <16 x bfloat> %{{.*}})
35+
func.func private @multi_reduction_1d_16_bf16(%v : vector<16xbf16>, %acc: bf16) -> bf16 {
36+
%0 = vector.multi_reduction <add>, %v, %acc[0] : vector<16xbf16> to bf16
37+
return %0 : bf16
38+
}
39+
// 512
40+
// CHECK-LABEL: @multi_reduction_1d_32_bf16
41+
// CHECK: call reassoc bfloat @llvm.vector.reduce.fadd.v32bf16(bfloat %{{.*}}, <32 x bfloat> %{{.*}})
42+
func.func private @multi_reduction_1d_32_bf16(%v : vector<32xbf16>, %acc: bf16) -> bf16 {
43+
%0 = vector.multi_reduction <add>, %v, %acc[0] : vector<32xbf16> to bf16
44+
return %0 : bf16
45+
}
46+
47+
// 1024
48+
// CHECK-LABEL: @multi_reduction_1d_64_bf16
49+
// CHECK: call reassoc bfloat @llvm.vector.reduce.fadd.v64bf16(bfloat %{{.*}}, <64 x bfloat> %{{.*}})
50+
func.func private @multi_reduction_1d_64_bf16(%v : vector<64xbf16>, %acc: bf16) -> bf16 {
51+
%0 = vector.multi_reduction <add>, %v, %acc[0] : vector<64xbf16> to bf16
52+
return %0 : bf16
53+
}
54+
55+
56+
// F32
57+
// moreElementsIf()
58+
// CHECK-LABEL: @multi_reduction_1d_16_f32
59+
// CHECK: call reassoc float @llvm.vector.reduce.fadd.v16f32(float %{{.*}}, <16 x float> %{{.*}})
60+
func.func private @multi_reduction_1d_16_f32(%v : vector<16xf32>, %acc: f32) -> f32 {
61+
%0 = vector.multi_reduction <add>, %v, %acc[0] : vector<16xf32> to f32
62+
return %0 : f32
63+
}
64+
65+
// CHECK-LABEL: @multi_reduction_1d_32_f32
66+
// CHECK: call reassoc float @llvm.vector.reduce.fadd.v32f32(float %{{.*}}, <32 x float> %{{.*}})
67+
func.func private @multi_reduction_1d_32_f32(%v : vector<32xf32>, %acc: f32) -> f32 {
68+
%0 = vector.multi_reduction <add>, %v, %acc[0] : vector<32xf32> to f32
69+
return %0 : f32
70+
}
71+
72+
// CHECK-LABEL: @multi_reduction_1d_64_f32
73+
// CHECK: call reassoc float @llvm.vector.reduce.fadd.v64f32(float %{{.*}}, <64 x float> %{{.*}})
74+
func.func private @multi_reduction_1d_64_f32(%v : vector<64xf32>, %acc: f32) -> f32 {
75+
%0 = vector.multi_reduction <add>, %v, %acc[0] : vector<64xf32> to f32
76+
return %0 : f32
77+
}
78+
79+
////////////////////////2D//////////////////////
80+
// Expected: Binary tree reduction(Converts 2D to 1D) + vector intrinsic
81+
82+
83+
// i32
84+
// = 512 bits (4x4 = 16 elements)
85+
// CHECK-LABEL: @multi_reduction_2d_4x4_i32
86+
// CHECK-COUNT-4: extractvalue
87+
// CHECK: shufflevector
88+
// CHECK-NEXT: shufflevector
89+
// CHECK: @llvm.vector.reduce.add.v16i32(<16 x i32> %{{.*}})
90+
func.func private @multi_reduction_2d_4x4_i32(%v : vector<4x4xi32>, %acc: i32) -> i32 {
91+
%0 = vector.multi_reduction <add>, %v, %acc[0, 1] : vector<4x4xi32> to i32
92+
return %0 : i32
93+
}
94+
// < 512 bits (2x4 = 8 elements)
95+
// CHECK-LABEL: @multi_reduction_2d_2x4_i32
96+
// CHECK-COUNT-2: extractvalue
97+
// CHECK: shufflevector
98+
// CHECK-NEXT: shufflevector
99+
// CHECK: @llvm.vector.reduce.add.v8i32(<8 x i32> %{{.*}})
100+
func.func private @multi_reduction_2d_2x4_i32(%v : vector<2x4xi32>, %acc: i32) -> i32 {
101+
%0 = vector.multi_reduction <add>, %v, %acc[0, 1] : vector<2x4xi32> to i32
102+
return %0 : i32
103+
}
104+
// > 512 bits (8x8 = 64 elements)
105+
// CHECK-LABEL: @multi_reduction_2d_8x8_i32
106+
// CHECK-COUNT-8: extractvalue
107+
// CHECK: shufflevector
108+
// CHECK-NEXT: shufflevector
109+
// CHECK: @llvm.vector.reduce.add.v64i32(<64 x i32> %{{.*}})
110+
func.func private @multi_reduction_2d_8x8_i32(%v : vector<8x8xi32>, %acc: i32) -> i32 {
111+
%0 = vector.multi_reduction <add>, %v, %acc[0, 1] : vector<8x8xi32> to i32
112+
return %0 : i32
113+
}
114+
115+
// bf16
116+
// 256 bits (4x4 = 16 elements)
117+
// CHECK-LABEL: @multi_reduction_2d_4x4_bf16
118+
// CHECK-COUNT-4: extractvalue
119+
// CHECK: shufflevector
120+
// CHECK-NEXT: shufflevector
121+
// CHECK: call reassoc bfloat @llvm.vector.reduce.fadd.v16bf16(bfloat %{{.*}}, <16 x bfloat> %{{.*}})
122+
func.func private @multi_reduction_2d_4x4_bf16(%v : vector<4x4xbf16>, %acc: bf16) -> bf16 {
123+
%0 = vector.multi_reduction <add>, %v, %acc[0, 1] : vector<4x4xbf16> to bf16
124+
return %0 : bf16
125+
}
126+
// 512 bits (8x4 = 32 elements)
127+
// CHECK-LABEL: @multi_reduction_2d_8x4_bf16
128+
// CHECK-COUNT-8: extractvalue
129+
// CHECK: shufflevector
130+
// CHECK-NEXT: shufflevector
131+
// CHECK: call reassoc bfloat @llvm.vector.reduce.fadd.v32bf16(bfloat %{{.*}}, <32 x bfloat> %{{.*}})
132+
func.func private @multi_reduction_2d_8x4_bf16(%v : vector<8x4xbf16>, %acc: bf16) -> bf16 {
133+
%0 = vector.multi_reduction <add>, %v, %acc[0, 1] : vector<8x4xbf16> to bf16
134+
return %0 : bf16
135+
}
136+
// 1024 bits (8x8 = 64 elements)
137+
// CHECK-LABEL: @multi_reduction_2d_8x8_bf16
138+
// CHECK-COUNT-8: extractvalue
139+
// CHECK: shufflevector
140+
// CHECK-NEXT: shufflevector
141+
// CHECK: call reassoc bfloat @llvm.vector.reduce.fadd.v64bf16(bfloat %{{.*}}, <64 x bfloat> %{{.*}})
142+
func.func private @multi_reduction_2d_8x8_bf16(%v : vector<8x8xbf16>, %acc: bf16) -> bf16 {
143+
%0 = vector.multi_reduction <add>, %v, %acc[0, 1] : vector<8x8xbf16> to bf16
144+
return %0 : bf16
145+
}
146+
147+
// f32
148+
// 512 bits (4x4 = 16 elements)
149+
// CHECK-LABEL: @multi_reduction_2d_4x4_f32
150+
// CHECK-COUNT-4: extractvalue
151+
// CHECK: shufflevector
152+
// CHECK-NEXT: shufflevector
153+
// CHECK: call reassoc float @llvm.vector.reduce.fadd.v16f32(float %{{.*}}, <16 x float> %{{.*}})
154+
func.func private @multi_reduction_2d_4x4_f32(%v : vector<4x4xf32>, %acc: f32) -> f32 {
155+
%0 = vector.multi_reduction <add>, %v, %acc[0, 1] : vector<4x4xf32> to f32
156+
return %0 : f32
157+
}
158+
// 1024 bits (8x4 = 32 elements)
159+
// CHECK-LABEL: @multi_reduction_2d_8x4_f32
160+
// CHECK-COUNT-8: extractvalue
161+
// CHECK: shufflevector
162+
// CHECK-NEXT: shufflevector
163+
// CHECK: call reassoc float @llvm.vector.reduce.fadd.v32f32(float %{{.*}}, <32 x float> %{{.*}})
164+
func.func private @multi_reduction_2d_8x4_f32(%v : vector<8x4xf32>, %acc: f32) -> f32 {
165+
%0 = vector.multi_reduction <add>, %v, %acc[0, 1] : vector<8x4xf32> to f32
166+
return %0 : f32
167+
}
168+
// 2048 bits (8x8 = 64 elements)
169+
// CHECK-LABEL: @multi_reduction_2d_8x8_f32
170+
// CHECK-COUNT-8: extractvalue
171+
// CHECK: shufflevector
172+
// CHECK-NEXT: shufflevector
173+
// CHECK: call reassoc float @llvm.vector.reduce.fadd.v64f32(float %{{.*}}, <64 x float> %{{.*}})
174+
func.func private @multi_reduction_2d_8x8_f32(%v : vector<8x8xf32>, %acc: f32) -> f32 {
175+
%0 = vector.multi_reduction <add>, %v, %acc[0, 1] : vector<8x8xf32> to f32
176+
return %0 : f32
177+
}

0 commit comments

Comments
 (0)