Commit 21d31e0

ggml-hexagon: fix swiglu failure at test-backend-ops (ggml-org#17344)
* refactor: use hvx_vec_exp_fp32_guard_inf for overflow handling in hvx_exp_f32
* feat: add fast sigmoid function with overflow guard for fp32
* refactor: replace hvx_vec_inverse_fp32 with hvx_vec_inverse_fp32_guard_inf for improved overflow handling
* feat: enhance hvx_add_scalar_f32 with overflow handling using infinity guard
* wip
* add HVX_Vector_Alias wip
* wip
* fix: improve handling of src1 tensor in glu_swiglu_fp32_per_thread function
* fix nc
* wip
* wip
* handle nan at inverse
* wip
* fix neg
* wip
* rename
* fix hvx_vec_inverse_fp32_guard_inf to handle infinity and NaN cases correctly
* wip
* fix hvx_vec_inverse_fp32_guard_inf to handle NaN cases correctly
* wip
* wip
* wip
* fix output sign
1 parent dd0f321 commit 21d31e0

File tree: 5 files changed, +99 −45 lines
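
For context: SwiGLU multiplies one half of each row by the sigmoid of the other half, and the failure fixed here comes from fp32 overflow inside the HVX approximations. For large negative x, exp(-x) overflows to +inf, and an unguarded reciprocal can turn 1/(1 + inf) into NaN instead of 0, which then poisons the product. A minimal scalar sketch of the intended reference behavior (illustrative only, not the HVX code):

#include <math.h>
#include <stdio.h>

// Reference semantics the guards restore: in IEEE fp32, expf(-x)
// saturates to +inf for x below roughly -88, 1 + inf == inf, and
// 1/inf == 0, so sigmoid(x) -> 0 and the swiglu product stays finite.
// The unguarded HVX approximations produced inf/NaN on this path.
static float sigmoid_ref(float x) { return 1.0f / (1.0f + expf(-x)); }

int main(void) {
    const float x = -100.0f, g = 2.0f;       // extreme input of the kind that exposed the bug
    printf("%g\n", x * sigmoid_ref(x) * g);  // prints -0, not nan
    return 0;
}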

ggml/src/ggml-hexagon/htp/act-ops.c
Lines changed: 12 additions & 18 deletions

@@ -106,33 +106,32 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
     t1 = HAP_perf_get_qtimer_count();
 
     int is_aligned = 1;
-    int opt_path   = 0;
     if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
         is_aligned = 0;
         FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
     }
-    if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
-        opt_path = 1;
-    }
 
     const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
     const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
     uint8_t * restrict       data_dst  = (uint8_t *) dst->data;
 
-    bool src1_valid = src1->ne[0];
+    const bool src1_valid = src1->ne[0];
+    const int  nc         = (src1_valid) ? ne00 : ne00 / 2;
     if (!src1_valid) {
-        data_src1     = data_src0;
-        src1_row_size = src0_row_size;
+        const int32_t swapped = op_params[1];
+        data_src1             = data_src0;
+        src1_row_size         = src0_row_size;
+
+        const size_t nc_in_bytes = nc * SIZEOF_FP32;
+        data_src0 += swapped ? nc_in_bytes : 0;
+        data_src1 += swapped ? 0 : nc_in_bytes;
     }
 
     uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
     uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
     uint8_t * restrict dst_spad_data  = dst_spad->data + (ith * dst_row_size);
 
-    const int32_t swapped = op_params[1];
-
-    const int nc = (src1_valid) ? ne0 : ne0 / 2;
-
+    const bool opt_path = ((1 == is_aligned) && !(nb01 & (VLEN - 1)));
     for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
         const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
         const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
@@ -142,12 +141,7 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
             htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
         }
 
-        if (!src1_valid) {
-            src0 += swapped ? nc : 0;
-            src1 += swapped ? 0 : nc;
-        }
-
-        if (1 == opt_path) {
+        if (opt_path) {
             hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
             hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
                                 (uint8_t *) dst, nc);
@@ -218,7 +212,7 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
     const float alpha = ((const float *) (op_params))[2];
     const float limit = ((const float *) (op_params))[3];
 
-    const int nc = (src1_valid) ? ne0 : ne0 / 2;
+    const int nc = (src1_valid) ? ne00 : ne00 / 2;
 
     for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
         const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
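
The key change above: when src1 is absent, the gate/value split of the packed row is now applied once to the base pointers, as a byte offset, instead of per row inside the loop, and nc is derived from ne00 rather than ne0. A minimal scalar sketch of the per-row math under the layout the diff describes (swiglu_row_ref is a hypothetical helper, not code from the commit):

#include <math.h>

// Each src0 row of length ne00 packs both halves side by side, and
// op_params[1] ("swapped") selects which half is gated through sigmoid.
static void swiglu_row_ref(const float * row, int ne00, int swapped, float * dst) {
    const int     nc = ne00 / 2;
    const float * x  = row + (swapped ? nc : 0);  // half fed to sigmoid
    const float * g  = row + (swapped ? 0 : nc);  // linear gate half
    for (int i = 0; i < nc; i++) {
        const float s = 1.0f / (1.0f + expf(-x[i]));
        dst[i] = x[i] * s * g[i];
    }
}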

ggml/src/ggml-hexagon/htp/hvx-exp.c
Lines changed: 19 additions & 6 deletions

@@ -16,6 +16,19 @@
 #include "hvx-utils.h"
 #include "ops-utils.h"
 
+static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec) {
+    static const float kInf    = INFINITY;
+    static const float kMaxExp = 88.02f;  // log(INF)
+
+    const HVX_Vector     max_exp = hvx_vec_splat_fp32(kMaxExp);
+    const HVX_Vector     inf     = hvx_vec_splat_fp32(kInf);
+    const HVX_VectorPred pred0   = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp);
+
+    HVX_Vector out = hvx_vec_exp_fp32(in_vec);
+
+    return Q6_V_vmux_QVV(pred0, inf, out);
+}
+
 void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
     int left_over       = num_elems & (VLEN_FP32 - 1);
     int num_elems_whole = num_elems - left_over;
@@ -42,9 +55,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
         for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
             if (true == negate) {
                 HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
-                *p_vec_out++          = hvx_vec_exp_fp32(neg_vec_in);
+                *p_vec_out++          = hvx_vec_exp_fp32_guard(neg_vec_in);
             } else {
-                *p_vec_out++ = hvx_vec_exp_fp32(*p_vec_in1++);
+                *p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++);
             }
         }
     } else {
@@ -54,9 +67,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
 
             if (true == negate) {
                 HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
-                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(neg_vec_in);
+                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in);
             } else {
-                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(in);
+                *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in);
             }
         }
     }
@@ -70,9 +83,9 @@ void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int
         if (true == negate) {
             HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
 
-            vec_out = hvx_vec_exp_fp32(neg_vec_in);
+            vec_out = hvx_vec_exp_fp32_guard(neg_vec_in);
         } else {
-            vec_out = hvx_vec_exp_fp32(in);
+            vec_out = hvx_vec_exp_fp32_guard(in);
        }
 
         hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
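
Per element, the guard's effect is equivalent to this scalar reference (the 88.02f threshold is the one used in the diff; expf stands in for the HVX polynomial approximation):

#include <math.h>

// Scalar equivalent of hvx_vec_exp_fp32_guard: inputs above the
// threshold are pinned to +inf instead of returning whatever the
// overflowed polynomial approximation would produce.
static float exp_guard_ref(float x) {
    return (x > 88.02f) ? INFINITY : expf(x);
}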

ggml/src/ggml-hexagon/htp/hvx-inverse.c
Lines changed: 3 additions & 3 deletions

@@ -38,13 +38,13 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const
 
 #pragma unroll(4)
         for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            *p_vec_out++ = hvx_vec_inverse_fp32(*p_vec_in++);
+            *p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++);
         }
     } else {
 #pragma unroll(4)
         for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
             HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32(in);
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in);
         }
     }
 
@@ -53,7 +53,7 @@ void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const
         float * dstf = (float *) dst + num_elems_whole;
 
         HVX_Vector in  = *(HVX_UVector *) srcf;
-        HVX_Vector out = hvx_vec_inverse_fp32(in);
+        HVX_Vector out = hvx_vec_inverse_fp32_guard(in);
 
         hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
     }

ggml/src/ggml-hexagon/htp/hvx-utils.c
Lines changed: 20 additions & 7 deletions

@@ -401,25 +401,34 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
         FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
     }
 
-    HVX_Vector val_vec = hvx_vec_splat_fp32(val);
+    static const float kInf = INFINITY;
+    const HVX_Vector   inf  = hvx_vec_splat_fp32(kInf);
+    HVX_Vector val_vec      = hvx_vec_splat_fp32(val);
 
     if (0 == unaligned_loop) {
         HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
         HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
 
 #pragma unroll(4)
         for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
-            HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, val_vec);
-            *vec_out++   = Q6_Vsf_equals_Vqf32(v);
+            HVX_Vector in = *vec_in1++;
+            const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
+            HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
+            v            = Q6_Vsf_equals_Vqf32(v);
+            v            = Q6_V_vmux_QVV(pred_inf, inf, v);
+            *vec_out++   = v;
         }
     } else {
 #pragma unroll(4)
         for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
             HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
 
-            HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
+            const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
+            HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
+            out            = Q6_Vsf_equals_Vqf32(out);
+            out            = Q6_V_vmux_QVV(pred_inf, inf, out);
 
-            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+            *(HVX_UVector *) (dst + i * SIZEOF_FP32) = out;
         }
     }
 
@@ -429,8 +438,12 @@ void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
 
         HVX_Vector in = *(HVX_UVector *) srcf;
 
-        HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
-        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
+        const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in);
+        HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
+        out            = Q6_Vsf_equals_Vqf32(out);
+        out            = Q6_V_vmux_QVV(pred_inf, inf, out);
+
+        hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
     }
 }
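
Per element, the guarded add above behaves like the scalar reference below, presumably because the qf32 add/convert round-trip does not propagate +inf intact. Note that Q6_Q_vcmp_eq_VwVw compares 32-bit words, so only the +inf bit pattern is preserved; -inf and NaN take the ordinary add path:

#include <math.h>
#include <stdint.h>
#include <string.h>

// Scalar equivalent of the guarded loop bodies: inputs whose bit
// pattern equals +inf (0x7f800000) bypass the add and stay +inf.
static float add_scalar_guard_ref(float x, float val) {
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));
    return (bits == 0x7f800000u) ? INFINITY : x + val;
}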

ggml/src/ggml-hexagon/htp/hvx-utils.h
Lines changed: 45 additions & 11 deletions

@@ -12,6 +12,15 @@
 #define VLEN_FP32 (VLEN / SIZEOF_FP32)
 #define VLEN_FP16 (VLEN / SIZEOF_FP16)
 
+typedef union {
+    HVX_Vector v;
+    uint8_t    b[VLEN];
+    uint16_t   h[VLEN_FP16];
+    uint32_t   w[VLEN_FP32];
+    __fp16     fp16[VLEN_FP16];
+    float      fp32[VLEN_FP32];
+} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias;
+
 static inline HVX_Vector hvx_vec_splat_fp32(float i) {
     union {
         float    f;
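
HVX_VectorAlias gives typed lane views of a 128-byte vector register, replacing ad-hoc unions like the one removed from hvx_vec_dump_fp16_n below. A usage sketch (illustrative only; compiles only for the HTP target, where HVX_Vector and the HVX intrinsics are available):

// Read lane 0 of a vector as a float; u.w / u.h / u.b give word,
// half-word, and byte views of the same register.
static float hvx_vec_first_lane_fp32(HVX_Vector v) {
    HVX_VectorAlias u = { .v = v };
    return u.fp32[0];
}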
@@ -243,19 +252,16 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3
 }
 
 static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
-    union {
-        HVX_Vector v;
-        __fp16     d[64];
-    } u = { .v = v };
+    HVX_VectorAlias u = { .v = v };
 
     const uint32_t n0 = n / 16;
     const uint32_t n1 = n % 16;
     int            i  = 0;
     for (; i < n0; i++) {
-        htp_dump_fp16_line(pref, u.d + (16 * i), 16);
+        htp_dump_fp16_line(pref, u.fp16 + (16 * i), 16);
     }
     if (n1) {
-        htp_dump_fp16_line(pref, u.d + (16 * i), n1);
+        htp_dump_fp16_line(pref, u.fp16 + (16 * i), n1);
     }
 }
 
@@ -411,8 +417,8 @@ static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n
 
     HVX_Vector sum = in, sum_t;
     while (width < total) {
-        sum_t = Q6_V_vror_VR(sum, width);                               // rotate right
-        sum   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t));  // elementwise sum
+        sum_t = Q6_V_vror_VR(sum, width);                                // rotate right
+        sum   = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t));   // elementwise sum
         width = width << 1;
     }
     return sum;
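
The hunk above only realigns comments, but the loop it touches is the classic log-step reduction: add the vector to a rotated copy of itself, doubling the rotation each step until every lane holds the total. A scalar model of the same data movement (hypothetical; plain arrays instead of HVX registers, n a power of two up to 64):

// After log2(n) rotate-and-add steps every slot holds the sum of all
// n inputs, mirroring Q6_V_vror_VR followed by the qf32 add.
static float reduce_sum_model(float lanes[], int n) {
    float rotated[64];
    for (int width = 1; width < n; width <<= 1) {
        for (int i = 0; i < n; i++) {
            rotated[i] = lanes[(i + width) % n];  // rotate by `width` lanes
        }
        for (int i = 0; i < n; i++) {
            lanes[i] += rotated[i];               // elementwise sum
        }
    }
    return lanes[0];                              // any lane holds the total
}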
@@ -491,7 +497,7 @@ static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) {
 static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) {
     // neg by setting the fp16 sign bit
     HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
-    return Q6_V_vor_VV(v, mask);
+    return Q6_V_vxor_VV(v, mask);
 }
 
 static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
@@ -506,7 +512,7 @@ static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
 #else
     // neg by setting the fp32 sign bit
     HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
-    return Q6_V_vor_VV(v, mask);
+    return Q6_V_vxor_VV(v, mask);
 #endif // __HTP_ARCH__ > 75
 }

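This is the "fix neg" / "fix output sign" item from the commit message: OR with the sign mask forces the sign bit on, so negating an already-negative value was a no-op, while XOR toggles it, which is true IEEE negation. In scalar terms:

#include <stdint.h>

static inline uint32_t fp32_neg_or_bug(uint32_t bits) { return bits | 0x80000000u; }  // old: always negative
static inline uint32_t fp32_neg_xor(uint32_t bits)    { return bits ^ 0x80000000u; }  // new: true negation
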
@@ -720,6 +726,24 @@ static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) {
     return Q6_Vsf_equals_Vqf32(r_qf);
 }
 
+static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf) {
+    static const float    kInf     = INFINITY;
+    static const uint32_t kNanMask = 0x7fffffff;
+    static const uint32_t kNanMin  = 0x7f800000;
+
+    const HVX_Vector     inf      = hvx_vec_splat_fp32(kInf);
+    const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(inf, v_sf);
+
+    HVX_Vector out = hvx_vec_inverse_fp32(v_sf);
+
+    const HVX_Vector nan_mask   = Q6_V_vsplat_R(kNanMask);
+    const HVX_Vector nan_min    = Q6_V_vsplat_R(kNanMin);
+    HVX_Vector       masked_out = Q6_V_vand_VV(out, nan_mask);
+    const HVX_VectorPred pred   = Q6_Q_vcmp_gtand_QVuwVuw(pred_inf, nan_min, masked_out);
+
+    return Q6_V_vmux_QVV(pred, out, Q6_V_vzero());
+}
+
 #define FAST_SIGMOID_LOG2F (0x3fb8aa3b)  // 1.442695022
 #define FAST_SIGMOID_C1    (0x3d009076)  // 0.03138777
 #define FAST_SIGMOID_C2    (0x3e8d74bd)  // 0.276281267
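
Element-wise, hvx_vec_inverse_fp32_guard keeps the raw reciprocal only when the input compares ordered-less-than +inf and the result is finite, flushing every other case (input +inf or NaN, overflowed or NaN result) to zero; in the sigmoid path this makes 1/(1 + exp(-x)) go to 0 rather than NaN as x approaches -inf. A scalar equivalent (1.0f/x stands in for the HVX reciprocal approximation):

#include <math.h>

static float inverse_guard_ref(float x) {
    const float r = 1.0f / x;
    return (x < INFINITY && isfinite(r)) ? r : 0.0f;
}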
@@ -934,6 +958,16 @@ static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) {
     return Q6_Vsf_equals_Vqf32(temp);
 }
 
+static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v) {
+    static const float kMaxExp = -88.02f;  // log(INF)
+
+    const HVX_Vector     max_exp  = Q6_V_vsplat_R(*((uint32_t *) &kMaxExp));
+    const HVX_VectorPred pred_inf = Q6_Q_vcmp_gt_VsfVsf(v, max_exp);
+
+    HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v);
+    return Q6_V_vmux_QVV(pred_inf, out, Q6_V_vzero());
+}
+
 static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
     int step_of_1 = num_elems >> 5;
     int remaining = num_elems - step_of_1 * VLEN_FP32;
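
hvx_vec_fast_sigmoid_fp32_guard short-circuits inputs at or below -88.02f, where the exp(-x) term would overflow inside the approximation, to the mathematical limit sigmoid(x) -> 0; NaN inputs fail the ordered compare and also map to 0. A scalar equivalent, with libm standing in for the polynomial:

#include <math.h>

static float fast_sigmoid_guard_ref(float x) {
    return (x > -88.02f) ? 1.0f / (1.0f + expf(-x)) : 0.0f;
}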
@@ -945,7 +979,7 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t *
 
 #pragma unroll(4)
         for (int i = 0; i < step_of_1; i++) {
-            v_dst[i] = hvx_vec_fast_sigmoid_fp32(v_src[i]);
+            v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i]);
         }
     }
