 #define NOINLINE __attribute__((__noinline__))
 #endif

-#if defined(__ARM_NEON) || defined(__AVX512F__)
+#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
 #define VECTOR_REGISTERS 32
 #else
 #define VECTOR_REGISTERS 16
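Note (commentary, not part of the patch): VXE/VXE2 joins NEON and AVX-512 in the 32-register tier because z/Architecture's vector facility provides 32 vector registers (v0–v31), so the larger blocking factor is safe there; SSE/AVX targets keep 16.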
@@ -109,6 +109,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
 inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

+#if defined(__VXE__) || defined(__VXE2__)
+inline float32x4_t add(float32x4_t x, float32x4_t y) { return vec_add(x, y); }
+inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vec_sub(x, y); }
+inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
+#endif
+
 #if defined(__MMA__)
 typedef vector unsigned char vec_t;
 typedef __vector_quad acc_t;
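For context (commentary, not part of the patch): `vec_add`, `vec_sub`, and `vec_mul` are the z/Architecture vector intrinsics from `<vecintrin.h>`, available with GCC/Clang's `-mzvector`; `float32x4_t` is assumed here to be ggml's alias for the native 128-bit `__vector float` type. A minimal standalone sketch under those assumptions:

```cpp
// Standalone usage sketch: assumes GCC or Clang on s390x with -mzvector,
// and that float32x4_t aliases the native 128-bit float vector as ggml does.
#include <vecintrin.h>
#include <stdio.h>

typedef __vector float float32x4_t;

int main() {
    float32x4_t x = {1.0f, 2.0f, 3.0f, 4.0f};
    float32x4_t y = {8.0f, 7.0f, 6.0f, 5.0f};
    float32x4_t z = vec_mul(vec_add(x, y), y);  // elementwise (x + y) * y
    printf("%g %g %g %g\n", z[0], z[1], z[2], z[3]);  // 72 63 54 45
    return 0;
}
```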
@@ -162,6 +168,13 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
 #endif
 #endif

+#if defined(__VXE__) || defined(__VXE2__)
+template <>
+inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
+    return vec_madd(a, b, c);
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED HORIZONTAL SUM

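Note: `vec_madd(a, b, c)` computes `a * b + c` per lane as a fused multiply-add, so e.g. `madd({1,2,3,4}, {5,6,7,8}, {1,1,1,1})` yields `{6, 13, 22, 33}`.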
@@ -178,6 +191,13 @@ inline float hsum(float16x8_t x) {
 }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

+#if defined(__VXE__) || defined(__VXE2__)
+inline float hsum(float32x4_t x) {
+    float32x4_t tmp = x + vec_reve(x);
+    return tmp[0] + tmp[1];
+}
+#endif
+
 #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 inline float hsum(__m128 x) {
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
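The VXE `hsum` above is a reverse-and-add reduction: `vec_reve(x)` flips the lane order, so `x + vec_reve(x)` is `{x0+x3, x1+x2, x2+x1, x3+x0}`, and the first two lanes of that already cover all four inputs. A worked sketch under the same assumptions as before:

```cpp
// Worked example of the reverse-and-add reduction (assumes -mzvector).
#include <vecintrin.h>

typedef __vector float float32x4_t;  // assumed alias, as in ggml

int main() {
    float32x4_t x   = {1.0f, 2.0f, 3.0f, 4.0f};
    float32x4_t tmp = x + vec_reve(x);  // {1+4, 2+3, 3+2, 4+1} = {5, 5, 5, 5}
    return (int)(tmp[0] + tmp[1]);      // 10 == 1 + 2 + 3 + 4
}
```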
@@ -227,6 +247,21 @@ template <> inline float32x4_t load(const ggml_fp16_t *p) {
 #endif // _MSC_VER
 #endif // __ARM_NEON

+#if defined(__VXE__) || defined(__VXE2__)
+template <> inline float32x4_t load(const ggml_fp16_t * p) {
+    float tmp[4];
+
+    for (int i = 0; i < 4; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(p[i]);
+    }
+
+    return vec_xl(0, (const float *)(tmp));
+}
+template <> inline float32x4_t load(const float * p) {
+    return vec_xl(0, p);
+}
+#endif
+
 #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 template <> inline __m128 load(const float *p) {
     return _mm_loadu_ps(p);
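Reading the VXE loads above: `vec_xl(offset, ptr)` is an unaligned vector load from `ptr` plus a byte offset, so the f32 overload fetches four consecutive floats directly, while the fp16 overload widens each half-float through scalar `GGML_FP16_TO_FP32` first, presumably because no native fp16-to-fp32 vector conversion is relied on here; everything downstream then stays in `float32x4_t`. A hypothetical fragment (assumes the surrounding sgemm.cpp context and ggml's conversion macros):

```cpp
// Hypothetical round-trip check of the fp16 overload, using ggml's
// conversion macros; all four values are exactly representable in fp16.
ggml_fp16_t h[4] = { GGML_FP32_TO_FP16(0.5f),  GGML_FP32_TO_FP16(1.5f),
                     GGML_FP32_TO_FP16(-2.0f), GGML_FP32_TO_FP16(8.0f) };
float32x4_t v = load<float32x4_t>(h);  // {0.5f, 1.5f, -2.0f, 8.0f}
```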
@@ -3319,6 +3354,14 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             (const float *)B, ldb,
             (float *)C, ldc};
         return tb.matmul(m, n);
+#elif defined(__VXE__) || defined(__VXE2__)
+        if (n < 4)
+            return false;
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc};
+        return tb.matmul(m, n);
 #elif defined(__MMA__)
         if (k % 8)
             return false;
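The VXE instantiation mirrors the NEON branch it sits next to: the leading `4` is the f32 lane count per vector, `float32x4_t` serves as both accumulator and operand vector type, and the `n < 4` guard appears to defer matrices narrower than one vector to ggml's generic fallback.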
@@ -3410,6 +3453,16 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
                 (float *)C, ldc};
             return tb.matmul(m, n);
         }
+#elif defined(__VXE__) || defined(__VXE2__)
+        if (n < 4)
+            return false;
+        if (Btype == GGML_TYPE_F16) {
+            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
+                k, (const ggml_fp16_t *)A, lda,
+                (const ggml_fp16_t *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
 #endif
         return false;
     }
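Unlike the NEON fp16 branch above, which computes directly in `float16x8_t`, the VXE fp16 path keeps the arithmetic in f32 vectors: A and B are widened lane-by-lane in the `load` specialization, only `Btype == GGML_TYPE_F16` is accepted, and C is written back as f32.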