Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 54 additions & 1 deletion ggml/src/ggml-cpu/llamafile/sgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
#define NOINLINE __attribute__((__noinline__))
#endif

#if defined(__ARM_NEON) || defined(__AVX512F__)
#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__VXE__) || defined(__VXE2__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
Expand Down Expand Up @@ -110,6 +110,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__VXE__) || defined(__VXE2__)
// Lane-wise addition of two float32x4_t vectors through the vec_add intrinsic.
inline float32x4_t add(float32x4_t x, float32x4_t y) {
    const float32x4_t sum = vec_add(x, y);
    return sum;
}
// Lane-wise subtraction (x - y) of two float32x4_t vectors through the vec_sub intrinsic.
inline float32x4_t sub(float32x4_t x, float32x4_t y) {
    const float32x4_t difference = vec_sub(x, y);
    return difference;
}
// Lane-wise multiplication of two float32x4_t vectors through the vec_mul intrinsic.
inline float32x4_t mul(float32x4_t x, float32x4_t y) {
    const float32x4_t product = vec_mul(x, y);
    return product;
}
#endif

#if defined(__MMA__)
typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;
Expand Down Expand Up @@ -163,6 +169,13 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
#endif
#endif

#if defined(__VXE__) || defined(__VXE2__)
// float32x4_t specialization of madd: forwards the three operands, in order,
// to the vec_madd intrinsic (multiply-add of a and b with addend c).
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
    const float32x4_t accumulated = vec_madd(a, b, c);
    return accumulated;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM

Expand All @@ -179,6 +192,13 @@ inline float hsum(float16x8_t x) {
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__VXE__) || defined(__VXE2__)
// Horizontal sum: collapses the four float lanes of x into one scalar.
inline float hsum(float32x4_t x) {
    // Adding x to its element-reversed copy leaves lane 0 holding x0 + x3 and
    // lane 1 holding x1 + x2, so those two lanes together cover all four inputs.
    const float32x4_t folded = x + vec_reve(x);
    const float outer_pair = folded[0];
    const float inner_pair = folded[1];
    return outer_pair + inner_pair;
}
#endif

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
Expand Down Expand Up @@ -228,6 +248,21 @@ template <> inline float32x4_t load(const ggml_fp16_t *p) {
#endif // _MSC_VER
#endif // __ARM_NEON

#if defined(__VXE__) || defined(__VXE2__)
// Loads four consecutive half-precision values starting at p and returns them
// widened to a float32x4_t.
template <> inline float32x4_t load(const ggml_fp16_t * p) {
    // Convert each half to float on the scalar side first (p[0] .. p[3], in
    // order), staging the results in a small stack buffer.
    const float widened[4] = {
        GGML_FP16_TO_FP32(p[0]),
        GGML_FP16_TO_FP32(p[1]),
        GGML_FP16_TO_FP32(p[2]),
        GGML_FP16_TO_FP32(p[3]),
    };

    // Pull the staged floats into a vector register with a zero byte offset.
    const float * staged = widened;
    return vec_xl(0, staged);
}
// Loads four consecutive floats starting at p into a float32x4_t using vec_xl
// with a zero byte offset.
template <> inline float32x4_t load(const float * p) {
    const float32x4_t lanes = vec_xl(0, p);
    return lanes;
}
#endif

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) {
return _mm_loadu_ps(p);
Expand Down Expand Up @@ -3323,6 +3358,14 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
(const float *)B, ldb,
(float *)C, ldc};
return tb.matmul(m, n);
#elif defined(__VXE__) || defined(__VXE2__)
if (n < 4)
return false;
tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
k, (const float *)A, lda,
(const float *)B, ldb,
(float *)C, ldc};
return tb.matmul(m, n);
#elif defined(__MMA__)
if (k % 8)
return false;
Expand Down Expand Up @@ -3414,6 +3457,16 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
(float *)C, ldc};
return tb.matmul(m, n);
}
#elif defined(__VXE__) || defined(__VXE2__)
if (n < 4)
return false;
if (Btype == GGML_TYPE_F16) {
tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
k, (const ggml_fp16_t *)A, lda,
(const ggml_fp16_t *)B, ldb,
(float *)C, ldc};
return tb.matmul(m, n);
}
#endif
return false;
}
Expand Down
5 changes: 5 additions & 0 deletions ggml/src/ggml-cpu/llamafile/sgemm.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#pragma once
#include <stdint.h>
#include <stdbool.h>

#if defined(__VXE__) || defined(__VXE2__)
#include <vecintrin.h>
#endif

#ifdef __cplusplus
extern "C" {
#endif
Expand Down
Loading