Skip to content

Commit 6f6e96d

Browse files
authored
Further optimized vector version for back to back macs (#2598)
1 parent 0756ce2 commit 6f6e96d

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

aie_kernels/aie2p/conv2dk14.cc

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -134,53 +134,55 @@ void conv2dk14_i8_vector(uint8_t *input, int8_t *kernels, int8_t *output,
134134
using MMUL8x8x8 = aie::mmul<8, 8, 8, uint8, int8>;
135135
::aie::set_saturation(
136136
aie::saturation_mode::saturate); // Needed to saturate properly to int8
137-
// ::aie::set_rounding(
138-
// aie::rounding_mode::positive_inf); // Needed to saturate properly to
139-
// uint8
140137
::aie::set_rounding(
141138
aie::rounding_mode::symmetric_inf); // Needed to round properly to int8
142139

143-
// constexpr unsigned VecFactor = 16;
144-
145140
aie::vector<int8, 64> zero64 = aie::zeros<int8, 64>();
146141

147142
MMUL8x8x8 acc1 = aie::zeros<acc32, 64>();
143+
MMUL8x8x8 acc2 = aie::zeros<acc32, 64>();
148144
aie::vector<int8, 64> maxv = aie::broadcast<int8, 64>(127);
149145

150-
const int output_channels_div_8 = output_channels / 8;
151-
// const int output_channels_div_8 = 2;
152-
const int tiles_div_8 = input_width / kernel_width / 8;
153-
// const int tiles_div_8 = 2;
154-
const int pixels_div_2 = kernel_width * kernel_width / 2;
155-
// const int pixels_div_2 = 98; // kernel_width * kernel_width / 2; // 14*14/2
156-
// = 98
146+
const int output_channels_div_8 = output_channels / 8; // 2
147+
const int tiles_div_8 = input_width / kernel_width / 8; // 2
148+
const int tiles_div_16 = input_width / kernel_width / 16; // 1
149+
const int pixels_div_2 = kernel_width * kernel_width / 2; // 98
157150

158-
uint8_t *in_ptr = input;
159-
int8_t *k_ptr = kernels;
160-
int8_t *out_ptr = output;
151+
uint8_t *__restrict in_ptr_1 = input;
152+
uint8_t *__restrict in_ptr_2 = input + 98 * 64;
153+
int8_t *__restrict k_ptr = kernels;
154+
int8_t *__restrict out_ptr = output;
161155

162156
for (int k = 0; k < output_channels_div_8; k++) { // 2
163-
for (int j = 0; j < tiles_div_8; j++) { // 2
157+
for (int j = 0; j < tiles_div_16; j++) { // 1
164158
AIE_PREPARE_FOR_PIPELINING
165159
AIE_LOOP_MIN_ITERATION_COUNT(98)
166160
// AIE_LOOP_UNROLL_FULL
167-
for (int i = 0; i < pixels_div_2; i++) { // 98
168-
auto tmp_a1 = aie::load_v<64>(in_ptr); // 8 tiles x 2 pixels
169-
in_ptr += 64;
161+
for (int i = 0; i < pixels_div_2; i++) { // 98
162+
auto tmp_a1 = aie::load_v<64>(in_ptr_1); // 8 tiles x 2 pixels
163+
in_ptr_1 += 64;
170164
auto tmp_a2 = aie::load_v<64>(k_ptr); // 2 pixels x 8 channels
171-
k_ptr += 64;
172165
acc1.mac(tmp_a1, tmp_a2); // 8 tiles x 8 channels (for 2 pixels)
166+
auto tmp_b1 = aie::load_v<64>(in_ptr_2); // 8 tiles x 2 pixels
167+
in_ptr_2 += 64;
168+
acc2.mac(tmp_b1, tmp_a2); // 8 tiles x 8 channels (for 2 pixels)
169+
k_ptr += 64;
173170
}
174171
aie::vector<int8, 64> o1 = acc1.to_vector<int8>(scale);
175-
// aie::vector<int8, 64> o1 = acc1.to_vector<int8>(10);
172+
aie::vector<int8, 64> o2 = acc2.to_vector<int8>(scale);
176173
aie::store_v(out_ptr, o1);
177-
// aie::store_v(out_ptr, maxv);
174+
out_ptr += 64;
175+
aie::store_v(out_ptr, o2);
178176
out_ptr += 64;
179177
acc1 = aie::zeros<acc32, 64>();
178+
acc2 = aie::zeros<acc32, 64>();
180179
k_ptr -= 64 * pixels_div_2;
180+
in_ptr_1 += pixels_div_2 * 64;
181+
in_ptr_2 += pixels_div_2 * 64;
181182
}
182183
k_ptr += 64 * pixels_div_2;
183-
in_ptr -= tiles_div_8 * 64 * pixels_div_2;
184+
in_ptr_1 -= 2 * tiles_div_16 * pixels_div_2 * 64;
185+
in_ptr_2 -= 2 * tiles_div_16 * pixels_div_2 * 64;
184186
}
185187

186188
event1();

0 commit comments

Comments
 (0)