
Commit 353c410

perf: remove switch in SVE2 activation
Move activation-function selection into a new templated dispatch function, which facilitates inlining of the activation code and increases performance. This commit addresses the SVE2 kernels.

Partially Resolves: COMPMID-8359
Signed-off-by: Dennis Wildmark <[email protected]>
Change-Id: Ia209585a2fa8ee7ca6561b8942fd748c3b451bdf
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/14723
Reviewed-by: Dongsung Kim <[email protected]>
Tested-by: Arm Jenkins <[email protected]>
Comments-Addressed: Arm Jenkins <[email protected]>
Benchmark: Arm Jenkins <[email protected]>
1 parent 937286e commit 353c410
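
The commit title says it removes a switch from the hot loop, and the message names the replacement: a templated dispatch function. Below is a minimal, self-contained sketch of that pattern in plain scalar C++ rather than the library's SVE2 intrinsics; the names (`Act`, `dispatch_activation`, `run`) are hypothetical and only illustrate the idea, not the library's actual API.

// Sketch of the switch-hoisting pattern (hypothetical names).
// The switch over the activation runs once per tensor; each case hands the
// caller's generic lambda a distinct callable type, so the per-element call
// inside the hot loop resolves at compile time and can be inlined.
#include <algorithm>
#include <cstdint>
#include <stdexcept>

enum class Act
{
    RELU,
    BOUNDED_RELU
};

template <typename F>
void dispatch_activation(Act act, uint8_t zero, uint8_t upper, F &&fn)
{
    switch (act)
    {
        case Act::RELU:
            fn([zero](uint8_t v) { return std::max(v, zero); });
            break;
        case Act::BOUNDED_RELU:
            fn([zero, upper](uint8_t v) { return std::min(upper, std::max(v, zero)); });
            break;
        default:
            throw std::runtime_error("Unsupported activation function");
    }
}

// Caller: the loop body is written once and the branch over `act` is hoisted
// out of it entirely.
void run(Act act, uint8_t zero, uint8_t upper, const uint8_t *in, uint8_t *out, int n)
{
    dispatch_activation(act, zero, upper,
                        [&](auto activation_function)
                        {
                            for (int x = 0; x < n; ++x)
                            {
                                out[x] = activation_function(in[x]);
                            }
                        });
}

Because each `case` instantiates the caller's lambda with a different closure type, `activation_function(...)` is a statically resolved call instead of a per-element branch, which is what the old code paid on every vector iteration.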

6 files changed: +702 additions, -471 deletions
Lines changed: 35 additions & 184 deletions
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2023 Arm Limited.
+ * Copyright (c) 2020-2023, 2025 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,6 +29,7 @@
 #include "src/core/NEON/SVEAsymm.h"
 #include "src/core/NEON/SVEMath.h"
 
+#include "qasymm8_impl.h"
 #include <arm_sve.h>
 #include <cmath>
 #include <cstddef>
@@ -52,189 +53,39 @@ void sve2_qasymm8_activation(const ITensor *src,
     Iterator input(src, win_collapsed);
     Iterator output(dst, win_collapsed);
 
-    const UniformQuantizationInfo qi_in  = src->info()->quantization_info().uniform();
-    const UniformQuantizationInfo qi_out = dst->info()->quantization_info().uniform();
-    const auto va       = svdup_n_u8(quantize_qasymm8(act_info.a(), qi_in));
-    const auto vb       = svdup_n_u8(quantize_qasymm8(act_info.b(), qi_in));
-    const auto const_0  = quantize_qasymm8(0.f, qi_in);
-    const auto vconst_0 = svdup_n_u8(const_0);
-    const auto vconst_1 = svdup_n_f32(1.f);
-    const auto va_f32   = svdup_n_f32(act_info.a());
-    const auto vb_f32   = svdup_n_f32(act_info.b());
-
-    // Initialise scale/offset for re-quantization
-    bool requant = true;
-    if (qi_in.scale == qi_out.scale && qi_in.offset == qi_out.offset)
-    {
-        requant = false;
-    }
-    float s  = qi_in.scale / qi_out.scale;
-    float o  = -qi_in.offset * s + qi_out.offset;
-    auto  vs = svdup_n_f32(s);
-    auto  vo = svdup_n_f32(o);
-
-    // Initialise scale/offset for re-quantization with int32_t
-    const auto voffset_in = svdup_n_s32(qi_in.offset);
-    int32_t    s_s32      = round(s * (1 << 8), arm_compute::RoundingPolicy::TO_NEAREST_EVEN);
-    int32_t    o_s32      = round(o * (1 << 8), arm_compute::RoundingPolicy::TO_NEAREST_EVEN);
-    const auto vs_s32     = svdup_n_s32(s_s32);
-    const auto vo_s32     = svdup_n_s32(o_s32);
-
-    // Initialise scale/offset for re-quantization for leaky relu
-    int32_t s_leaky_s32 = round(s * act_info.a() * (1 << 8), arm_compute::RoundingPolicy::TO_NEAREST_EVEN);
-    int32_t o_leaky_s32 = round((-qi_in.offset * s * act_info.a() + qi_out.offset) * (1 << 8),
-                                arm_compute::RoundingPolicy::TO_NEAREST_EVEN);
-    const auto vs_leaky_s32 = svdup_n_s32(s_leaky_s32);
-    const auto vo_leaky_s32 = svdup_n_s32(o_leaky_s32);
-
-    execute_window_loop(
-        win_collapsed,
-        [&](const Coordinates &)
-        {
-            const auto input_ptr  = reinterpret_cast<const uint8_t *>(input.ptr());
-            const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
-
-            svuint8_t tmp;
-
-            int      x  = window_start_x;
-            svbool_t pg = svwhilelt_b8(x, window_end_x);
-            do
-            {
-                const auto vin = svld1_u8(pg, input_ptr + x);
-                if (act == ActivationLayerInfo::ActivationFunction::RELU)
-                {
-                    // Perform activation
-                    tmp = svmax_u8_z(pg, vconst_0, vin);
-                    // Re-quantize to new output space
-                    tmp = requant ? svmla_qasymm8_z(pg, tmp, vs, vo) : tmp;
-                }
-                else if (act == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU)
-                {
-                    // Perform activation
-                    tmp = svmin_u8_z(pg, va, svmax_u8_z(pg, vconst_0, vin));
-                    // Re-quantize to new output space
-                    tmp = requant ? svmla_qasymm8_z(pg, tmp, vs, vo) : tmp;
-                }
-                else if (act == ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU)
-                {
-                    // Perform activation
-                    tmp = svmin_u8_z(pg, va, svmax_u8_z(pg, vb, vin));
-                    // Re-quantize to new output space
-                    tmp = svmla_qasymm8_z(pg, tmp, vs, vo);
-                }
-                else if (act == ActivationLayerInfo::ActivationFunction::LOGISTIC)
-                {
-                    // De-quantize
-                    const auto vin_deq = svdequantize_z(pg, vin, qi_in);
-                    // Perform activation
-                    const svfloat32x4_t tmp_dep = svcreate4_f32(
-                        svdiv_f32_z(
-                            pg, vconst_1,
-                            svadd_f32_z(pg, vconst_1, svexp_f32_z(pg, svneg_f32_z(pg, svget4_f32(vin_deq, 0))))),
-                        svdiv_f32_z(
-                            pg, vconst_1,
-                            svadd_f32_z(pg, vconst_1, svexp_f32_z(pg, svneg_f32_z(pg, svget4_f32(vin_deq, 1))))),
-                        svdiv_f32_z(
-                            pg, vconst_1,
-                            svadd_f32_z(pg, vconst_1, svexp_f32_z(pg, svneg_f32_z(pg, svget4_f32(vin_deq, 2))))),
-                        svdiv_f32_z(
-                            pg, vconst_1,
-                            svadd_f32_z(pg, vconst_1, svexp_f32_z(pg, svneg_f32_z(pg, svget4_f32(vin_deq, 3))))));
-
-                    // Re-quantize to new output space
-                    tmp = svquantize_z(pg, tmp_dep, qi_out);
-                }
-                else if (act == ActivationLayerInfo::ActivationFunction::TANH)
-                {
-                    // De-quantize
-                    const auto vin_deq = svdequantize_z(pg, vin, qi_in);
-                    // Perform activation
-                    const svfloat32x4_t tmp_dep = svcreate4_f32(
-                        svmul_f32_z(pg, va_f32, svtanh_f32_z(pg, svmul_f32_z(pg, svget4_f32(vin_deq, 0), vb_f32))),
-                        svmul_f32_z(pg, va_f32, svtanh_f32_z(pg, svmul_f32_z(pg, svget4_f32(vin_deq, 1), vb_f32))),
-                        svmul_f32_z(pg, va_f32, svtanh_f32_z(pg, svmul_f32_z(pg, svget4_f32(vin_deq, 2), vb_f32))),
-                        svmul_f32_z(pg, va_f32, svtanh_f32_z(pg, svmul_f32_z(pg, svget4_f32(vin_deq, 3), vb_f32))));
-
-                    // Re-quantize to new output space
-                    tmp = svquantize_z(pg, tmp_dep, qi_out);
-                }
-                else if (act == ActivationLayerInfo::ActivationFunction::LEAKY_RELU)
-                {
-                    svbool_t    p0, p1, p2, p3;
-                    svint32x4_t tmp_dep;
-
-                    // Expand to int32
-                    const svint32x4_t vin_s32 = svcreate4_s32(svreinterpret_s32_u32(svmovlb_u32(svmovlb_u16(vin))),
-                                                              svreinterpret_s32_u32(svmovlt_u32(svmovlb_u16(vin))),
-                                                              svreinterpret_s32_u32(svmovlb_u32(svmovlt_u16(vin))),
-                                                              svreinterpret_s32_u32(svmovlt_u32(svmovlt_u16(vin))));
-
-                    // Compare elements to input offset
-                    if (qi_in.scale >= 0)
-                    {
-                        p0 = svcmplt_s32(pg, svget4_s32(vin_s32, 0), voffset_in);
-                        p1 = svcmplt_s32(pg, svget4_s32(vin_s32, 1), voffset_in);
-                        p2 = svcmplt_s32(pg, svget4_s32(vin_s32, 2), voffset_in);
-                        p3 = svcmplt_s32(pg, svget4_s32(vin_s32, 3), voffset_in);
-                    }
-                    else
-                    {
-                        p0 = svcmpgt_s32(pg, svget4_s32(vin_s32, 0), voffset_in);
-                        p1 = svcmpgt_s32(pg, svget4_s32(vin_s32, 1), voffset_in);
-                        p2 = svcmpgt_s32(pg, svget4_s32(vin_s32, 2), voffset_in);
-                        p3 = svcmpgt_s32(pg, svget4_s32(vin_s32, 3), voffset_in);
-                    }
-
-                    // Multiply negative elements and requantize if necessary
-                    if (requant)
-                    {
-                        tmp_dep = svcreate4_s32(
-                            svasr_n_s32_m(pg,
-                                          svmla_s32_m(pg, svsel(p0, vo_leaky_s32, vo_s32), svget4_s32(vin_s32, 0),
-                                                      svsel(p0, vs_leaky_s32, vs_s32)),
-                                          8),
-                            svasr_n_s32_m(pg,
-                                          svmla_s32_m(pg, svsel(p1, vo_leaky_s32, vo_s32), svget4_s32(vin_s32, 1),
-                                                      svsel(p1, vs_leaky_s32, vs_s32)),
-                                          8),
-                            svasr_n_s32_m(pg,
-                                          svmla_s32_m(pg, svsel(p2, vo_leaky_s32, vo_s32), svget4_s32(vin_s32, 2),
-                                                      svsel(p2, vs_leaky_s32, vs_s32)),
-                                          8),
-                            svasr_n_s32_m(pg,
-                                          svmla_s32_m(pg, svsel(p3, vo_leaky_s32, vo_s32), svget4_s32(vin_s32, 3),
-                                                      svsel(p3, vs_leaky_s32, vs_s32)),
-                                          8));
-                    }
-                    else
-                    {
-                        tmp_dep = svcreate4_s32(
-                            svasr_n_s32_m(p0, svmad_s32_m(p0, svget4_s32(vin_s32, 0), vs_leaky_s32, vo_leaky_s32), 8),
-                            svasr_n_s32_m(p1, svmad_s32_m(p1, svget4_s32(vin_s32, 1), vs_leaky_s32, vo_leaky_s32), 8),
-                            svasr_n_s32_m(p2, svmad_s32_m(p2, svget4_s32(vin_s32, 2), vs_leaky_s32, vo_leaky_s32), 8),
-                            svasr_n_s32_m(p3, svmad_s32_m(p3, svget4_s32(vin_s32, 3), vs_leaky_s32, vo_leaky_s32), 8));
-                    }
-
-                    // Convert uint32 vectors to uint16 vectors (with saturation)
-                    const auto v_low_u16  = svqxtunt_s32(svqxtunb_s32(svget4_s32(tmp_dep, 0)), svget4_s32(tmp_dep, 1));
-                    const auto v_high_u16 = svqxtunt_s32(svqxtunb_s32(svget4_s32(tmp_dep, 2)), svget4_s32(tmp_dep, 3));
-
-                    // convert uint16 vectors to uint8 vectors (with saturation)
-                    tmp = svqxtnt_u16(svqxtnb_u16(v_low_u16), v_high_u16);
-                }
-                else
-                {
-                    ARM_COMPUTE_ERROR("Unsupported activation function");
-                }
-
-                svst1_u8(pg, output_ptr + x, tmp);
-
-                x += svcntb();
-                pg = svwhilelt_b8(x, window_end_x);
-
-            } while (svptest_any(svptrue_b8(), pg));
-        },
-        input, output);
+    const UniformQuantizationInfo qi_in  = src->info()->quantization_info().uniform();
+    const UniformQuantizationInfo qi_out = dst->info()->quantization_info().uniform();
+
+    dispatch_sve2_qasymm8_activation_function(act, act_info, qi_in, qi_out,
+                                              [&](auto activation_function)
+                                              {
+                                                  execute_window_loop(
+                                                      win_collapsed,
+                                                      [&](const Coordinates &)
+                                                      {
+                                                          const auto input_ptr =
+                                                              reinterpret_cast<const uint8_t *>(input.ptr());
+                                                          const auto output_ptr =
+                                                              reinterpret_cast<uint8_t *>(output.ptr());
+
+                                                          svuint8_t tmp;
+
+                                                          int      x  = window_start_x;
+                                                          svbool_t pg = svwhilelt_b8(x, window_end_x);
+                                                          do
+                                                          {
+                                                              const auto vin = svld1_u8(pg, input_ptr + x);
+                                                              tmp            = activation_function(vin, pg);
+
+                                                              svst1_u8(pg, output_ptr + x, tmp);
+
+                                                              x += svcntb();
+                                                              pg = svwhilelt_b8(x, window_end_x);
+
+                                                          } while (svptest_any(svptrue_b8(), pg));
+                                                      },
+                                                      input, output);
+                                              });
 }
 } // namespace cpu
 } // namespace arm_compute
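
The removed setup code folds de-quantization and re-quantization into a single affine map, y = x * s + o with s = scale_in / scale_out and o = -offset_in * s + offset_out, and also keeps a Q24.8 fixed-point copy (round(s * (1 << 8)), later undone with an arithmetic shift right by 8) for the integer leaky-relu path. A scalar sketch of that math, with made-up quantization parameters for illustration:

// Scalar illustration of the re-quantization math from the removed code.
// The quantization parameters below are hypothetical example values.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    const float   scale_in = 0.05f, scale_out = 0.1f;
    const int32_t offset_in = 10, offset_out = 5;

    // Fold dequantize + requantize into one affine map.
    const float s = scale_in / scale_out;
    const float o = -offset_in * s + offset_out;

    // Q24.8 fixed-point equivalents, rounded to nearest (ties to even under
    // the default rounding mode), mirroring TO_NEAREST_EVEN in the kernel.
    const int32_t s_s32 = static_cast<int32_t>(std::nearbyint(s * (1 << 8)));
    const int32_t o_s32 = static_cast<int32_t>(std::nearbyint(o * (1 << 8)));

    const uint8_t x = 42;
    // Float path: apply the affine map, then round.
    const int32_t y_f = static_cast<int32_t>(std::lround(x * s + o));
    // Fixed-point path: multiply-accumulate in int32, then shift out the Q8
    // fraction (arithmetic shift, as the kernel does with svasr).
    const int32_t y_i = (x * s_s32 + o_s32) >> 8;

    // Both paths saturate to the uint8 range before storing.
    std::printf("float path: %d, fixed-point path: %d\n",
                std::clamp(y_f, 0, 255), std::clamp(y_i, 0, 255));
    return 0;
}

With these example values both paths yield 21, showing why the kernel can use the cheap integer form inside the loop; the leaky-relu variant simply selects between (s, o) and the a()-scaled (s_leaky, o_leaky) pair per lane before the same multiply-accumulate-and-shift.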
