1- // #include <iostream>
2- #include < cstdint>
31#include < cassert>
42#include < algorithm>
53
4+ #define GGML_COMMON_DECL_CPP
5+ #include " ggml-common.h"
6+ #include " ggml.h"
7+
68#include " ggml-fp8.h"
79
810/*
@@ -20,6 +22,11 @@ cmake --build build --config Release
2022./build/bin/llama-cli -c 1024 -m ~/LLM/Mistral-Nemo-Instruct-2407.E3M4_Q.gguf -p "[INST]bonjour a tu un nom. je ne sais pas comment t'appeler. Si tu n'en as pas je peux t'appeler TINTIN[/INST]" -s 42
2123./build/bin/llama-perplexity --kl-divergence-base ~/LLM/Mistral-Nemo-Instruct-2407.BF16.kld --kl-divergence -s 31337 -m ~/LLM/Mistral-Nemo-Instruct-2407.E3M4_Q.gguf
2224
25+ # la CI local:
26+ rm -rf tmp
27+ mkdir tmp
28+ bash ./ci/run.sh ./tmp/results ./tmp/mnt
29+
2330*/
2431
2532template <int N> constexpr float EXP2 () {
@@ -45,7 +52,9 @@ struct FP8 {
4552 static constexpr float MIN () { return EXP2<-M>()*EXP2<2 -EXP_I2<_E-1 >()>(); }
4653 // =============================================
4754
55+ #ifdef GGML_USE_OPENMP_SIMD
4856 #pragma omp declare simd
57+ #endif
4958 void operator =(float value) {
5059 union {
5160 float f;
@@ -67,7 +76,9 @@ struct FP8 {
6776 }
6877 }
6978
79+ #ifdef GGML_USE_OPENMP_SIMD
7080 #pragma omp declare simd
81+ #endif
7182 operator float () const {
7283 union {
7384 float f;
@@ -84,23 +95,21 @@ struct FP8 {
8495 }
8596};
8697
87- // block_e4m3_q
88- // typedef struct {
89- // float d; // delta
90- // ggml_e4m3 qs[QK_K];
91- // } block_e4m3_q;
92-
9398template <int E>
9499static inline void conv (const FP8<E>* x, float * y, int64_t size) {
100+ #ifdef GGML_USE_OPENMP_SIMD
95101 #pragma omp simd
102+ #endif
96103 for (int64_t i=0 ; i<size; i++) {
97104 y[i] = (float ) x[i];
98105 }
99106}
100107
101108template <int E>
102109static inline void conv (const float * x, FP8<E>* y, int64_t size) {
110+ #ifdef GGML_USE_OPENMP_SIMD
103111 #pragma omp simd
112+ #endif
104113 for (int64_t i=0 ; i<size; i++) {
105114 y[i] = x[i];
106115 }
@@ -109,7 +118,9 @@ static inline void conv(const float* x, FP8<E>* y, int64_t size) {
109118template <int E>
110119static inline float dot (const FP8<E>* x, const float * y, int64_t size) {
111120 float z = 0 ;
121+ #ifdef GGML_USE_OPENMP_SIMD
112122 #pragma omp simd reduction(+:z)
123+ #endif
113124 for (int64_t i=0 ; i<size; i++) {
114125 z += ((float )x[i])*y[i];
115126 }
@@ -126,7 +137,9 @@ template <int E, int QK>
126137static inline void conv (const bloc_fp8<E, QK>* x, float * y, int64_t size) {
127138 const auto qk_size = size / QK;
128139 for (int64_t q=0 ; q<qk_size; ++q) {
140+ #ifdef GGML_USE_OPENMP_SIMD
129141 #pragma omp simd
142+ #endif
130143 for (int64_t i=0 ; i<QK; i++) {
131144 y[q*QK+i] = ((float ) x[q].qs [i])*(x[q]).d ;
132145 }
@@ -138,13 +151,18 @@ static inline void conv(const float* x, bloc_fp8<E, QK>* y, int64_t size) {
138151 const auto qk_size = size / QK;
139152 for (int64_t q=0 ; q<qk_size; ++q) {
140153 float m = 0 ;
154+ // @ voir si c'est lui qui pose probleme et si c'est sur toutes les target
155+ #ifdef GGML_USE_OPENMP_SIMD
141156 #pragma omp simd reduction(max:m)
157+ #endif
142158 for (int64_t i=0 ; i<QK; i++) {
143159 m = std::max (std::abs (x[q*QK+i]),m);
144160 }
145161 const float D = FP8<E>::MAX ()/m;
146162 y[q].d = m/FP8<E>::MAX ();
163+ #ifdef GGML_USE_OPENMP_SIMD
147164 #pragma omp simd
165+ #endif
148166 for (int64_t i=0 ; i<QK; i++) {
149167 y[q].qs [i] = x[q*QK+i]*D;
150168 }
@@ -157,7 +175,9 @@ static inline float dot(const bloc_fp8<E, QK>* x, const float* y, int64_t size)
157175 const auto qk_size = size / QK;
158176 for (int64_t q=0 ; q<qk_size; ++q) {
159177 float z0 = 0 ;
178+ #ifdef GGML_USE_OPENMP_SIMD
160179 #pragma omp simd reduction(+:z0)
180+ #endif
161181 for (int64_t i=0 ; i<QK; i++) {
162182 z0 += ((float )x[q].qs [i])*y[q*QK+i];
163183 }
@@ -192,29 +212,29 @@ void ggml_fp32_to_e4m3_row_ref(const float * GGML_RESTRICT x, ggml_e4m3_t * GGML
192212}
193213
194214void dequantize_row_e4m3_q (const block_e4m3_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
195- assert (k % FP8_QK == 0 );
196- conv (reinterpret_cast <const bloc_fp8<4 , FP8_QK >*>(x), y, k);
215+ assert (k % QK_K == 0 );
216+ conv (reinterpret_cast <const bloc_fp8<4 , QK_K >*>(x), y, k);
197217}
198218void quantize_row_e4m3_q (const float * GGML_RESTRICT x, block_e4m3_q * GGML_RESTRICT y, int64_t k) {
199- assert (k % FP8_QK == 0 );
200- conv (x, reinterpret_cast <bloc_fp8<4 , FP8_QK >*>(y), k);
219+ assert (k % QK_K == 0 );
220+ conv (x, reinterpret_cast <bloc_fp8<4 , QK_K >*>(y), k);
201221}
202222void quantize_row_e4m3_q_ref (const float * GGML_RESTRICT x, block_e4m3_q * GGML_RESTRICT y, int64_t k) {
203- assert (k % FP8_QK == 0 );
204- conv (x, reinterpret_cast <bloc_fp8<4 , FP8_QK >*>(y), k);
223+ assert (k % QK_K == 0 );
224+ conv (x, reinterpret_cast <bloc_fp8<4 , QK_K >*>(y), k);
205225}
206226
207227void dequantize_row_e3m4_q (const block_e3m4_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
208- assert (k % FP8_QK == 0 );
209- conv (reinterpret_cast <const bloc_fp8<3 , FP8_QK >*>(x), y, k);
228+ assert (k % QK_K == 0 );
229+ conv (reinterpret_cast <const bloc_fp8<3 , QK_K >*>(x), y, k);
210230}
211231void quantize_row_e3m4_q (const float * GGML_RESTRICT x, block_e3m4_q * GGML_RESTRICT y, int64_t k) {
212- assert (k % FP8_QK == 0 );
213- conv (x, reinterpret_cast <bloc_fp8<3 , FP8_QK >*>(y), k);
232+ assert (k % QK_K == 0 );
233+ conv (x, reinterpret_cast <bloc_fp8<3 , QK_K >*>(y), k);
214234}
215235void quantize_row_e3m4_q_ref (const float * GGML_RESTRICT x, block_e3m4_q * GGML_RESTRICT y, int64_t k) {
216- assert (k % FP8_QK == 0 );
217- conv (x, reinterpret_cast <bloc_fp8<3 , FP8_QK >*>(y), k);
236+ assert (k % QK_K == 0 );
237+ conv (x, reinterpret_cast <bloc_fp8<3 , QK_K >*>(y), k);
218238}
219239
220240// the dot product for FP8 weight
@@ -242,7 +262,7 @@ void ggml_vec_dot_e4m3_q(int n, float * GGML_RESTRICT s, size_t bs, const block_
242262 GGML_UNUSED (bx);
243263 GGML_UNUSED (by);
244264 GGML_UNUSED (bs);
245- *s = dot (reinterpret_cast <const bloc_fp8<4 , FP8_QK >*>(vx), vy, n);
265+ *s = dot (reinterpret_cast <const bloc_fp8<4 , QK_K >*>(vx), vy, n);
246266}
247267
248268void ggml_vec_dot_e3m4_q (int n, float * GGML_RESTRICT s, size_t bs, const block_e3m4_q * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc) {
@@ -251,5 +271,5 @@ void ggml_vec_dot_e3m4_q(int n, float * GGML_RESTRICT s, size_t bs, const block_
251271 GGML_UNUSED (bx);
252272 GGML_UNUSED (by);
253273 GGML_UNUSED (bs);
254- *s = dot (reinterpret_cast <const bloc_fp8<3 , FP8_QK >*>(vx), vy, n);
274+ *s = dot (reinterpret_cast <const bloc_fp8<3 , QK_K >*>(vx), vy, n);
255275}
0 commit comments