@@ -39,6 +39,12 @@ class Model {
         m_instance.reset(new ac::llama::Instance(*m_model, {
             .ctxSize = 2048,
         }));
+        // start a single session up front; it now lives for the wrapper's whole lifetime
+        m_session = &m_instance->startSession({});
+        m_session->setInitialPrompt({}); // empty prompt
+    }
+
+    ~Model() {
+        m_instance->stopSession();
     }
 
     struct GenerationResult {
@@ -48,10 +54,27 @@ class Model {
     };
 
     GenerationResult generate(std::string prompt, uint32_t maxTokens) {
-        m_session = &m_instance->startSession({});
-
         auto promptTokens = m_model->vocab().tokenize(prompt, true, true);
-        m_session->setInitialPrompt(promptTokens);
+        return generate_impl(promptTokens, maxTokens);
+    }
+
+    GenerationResult generate(std::span<ac::llama::Token> prompt, uint32_t maxTokens) {
+        return generate_impl(prompt, maxTokens);
+    }
+
+    std::vector<ac::llama::Token> tokenize(std::string prompt) {
+        return m_model->vocab().tokenize(prompt, true, true);
+    }
+
+    // a token id is valid for this model if it falls inside the vocabulary range
+    bool tokenExists(ac::llama::Token token) {
+        return m_model->vocab().nTokens() > token;
+    }
+
+private:
+    GenerationResult generate_impl(std::span<ac::llama::Token> promptTokens, uint32_t maxTokens) {
+        // the session persists across calls, so an empty prompt means
+        // "continue from the current context"
+        if (!promptTokens.empty()) {
+            m_session->pushPrompt(promptTokens, {});
+        }
 
         constexpr int32_t topK = 10;
         auto data = m_session->getSampledTokenData(topK);
@@ -85,13 +108,15 @@ class Model {
             });
         }
 
-        m_instance->stopSession();
-        m_session = nullptr;
+        // reconstruct the prompt text from the tokens for the result report
+        std::string initialPrompt = "";
+        for (size_t i = 0; i < promptTokens.size(); i++) {
+            initialPrompt += m_model->vocab().tokenToString(promptTokens[i], false);
+        }
 
         return {
-            .initalPrompt = prompt,
-            .result = result,
-            .steps = genSteps
+            .initalPrompt = std::move(initialPrompt),
+            .result = std::move(result),
+            .steps = std::move(genSteps)
         };
     }
 
@@ -101,6 +126,35 @@ class Model {
     ac::llama::Session* m_session;
 };
 
+// -- Helper function to compute normalized entropy --
+// Returns a value in [0, 1]: ~0 for a sharply peaked (confident) distribution,
+// 1 for a uniform (maximally uncertain) one.
+float normalizedEntropy(const ac::llama::TokenDataVector& data) {
+    // fewer than two candidates carry no uncertainty; this also avoids
+    // dividing by log(1) == 0 below
+    if (data.size() < 2) {
+        return 0.0f;
+    }
+
+    std::vector<float> probs(data.size());
+    float sum = 0.0f;
+
+    // Shift by the max logit so the exponentials can't overflow;
+    // softmax is invariant under a constant shift.
+    float maxLogit = data[0].logit;
+    for (auto& val : data) {
+        if (val.logit > maxLogit) maxLogit = val.logit;
+    }
+
+    // Calculate softmax probabilities
+    for (auto& val : data) {
+        sum += std::exp(val.logit - maxLogit);
+    }
+    for (size_t i = 0; i < data.size(); ++i) {
+        probs[i] = std::exp(data[i].logit - maxLogit) / sum;
+    }
+
+    // Calculate entropy
+    float entropy = 0.0f;
+    for (float p : probs) {
+        if (p > 0.0f) {
+            entropy -= p * std::log(p);
+        }
+    }
+
+    // Normalize by the maximum possible entropy, log(number of classes)
+    float maxEntropy = std::log(float(probs.size()));
+    return entropy / maxEntropy;
+}
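+
+// Quick illustrative check of the scale (the comparison below samples topK = 10
+// candidates per step): ten equal logits give p = 0.1 each, so the entropy is
+// log(10) and the function returns log(10)/log(10) = 1.0, while a near-certain
+// token drives it toward 0.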
+
 int main() try {
     ac::jalog::Instance jl;
     jl.setup().add<ac::jalog::sinks::ColorSink>();
@@ -112,7 +166,10 @@ int main() try {
 
     // load model
     std::string tmpFolder = AC_TEST_DATA_LLAMA_DIR "/../../../tmp/";
-    std::string modelGguf = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
+    // std::string modelGguf = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
+    std::string modelGguf = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
+    // std::string modelGguf = "BgGPT-Gemma-2-2B-IT-v1.0.Q8_0.gguf";
     std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
 
     Model m1(tmpFolder + modelGguf, {});
@@ -128,21 +185,66 @@ int main() try {
 
     for (int i = 0; i < 1; ++i) {
 
-        auto res = m1.generate(prompt, 100);
+        auto res = m1.generate(prompt, 1000);
         std::cout << "Model 1 generated: " << res.result << "\n";
         std::string genPrompt = res.initalPrompt;
+
+        auto genPromptTokens = m2.tokenize(genPrompt);
+
+        float totalWeightedDist = 0.0f;
+        float totalWeight = 0.0f;
+
         for (size_t i = 0; i < res.steps.size(); i++) {
             auto& step = res.steps[i];
             if (i > 0) {
-                genPrompt += step.tokenStr;
+                if (m2.tokenExists(step.token)) {
+                    genPromptTokens.push_back(step.token);
+                }
+                else {
+                    // Instead of skipping, penalize fully
+                    float fakeDist = 1.0f; // maximum possible distance
+                    float weight = 1.0f;   // assume maximum confidence since we can't know the entropy
+                    totalWeightedDist += weight * fakeDist;
+                    totalWeight += weight;
+
+                    std::cout << "Token not found in model 2: " << step.tokenStr << "\n";
+                    continue;
+                }
             }
-            auto res2 = m2.generate(genPrompt, 0);
-            assert(res2.steps.size() == 1);
 
-            if (ac::llama::LogitComparer::compare(step.data, res2.steps[0].data)) {
-                std::cout << "Models are the same. Generated str by now:\n" << genPrompt << "\n\n";
+            Model::GenerationResult res2;
+            if (i == 0) {
+                // first comparison: prime model 2 with the full tokenized prompt
+                res2 = m2.generate(genPromptTokens, 0);
+            } else {
+                // the session is stateful, so only the newest token needs to be pushed
+                std::vector<ac::llama::Token> token{step.token};
+                res2 = m2.generate(token, 0);
             }
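+            // Note: pushing model 1's token ids straight into model 2 assumes the two
+            // models share a compatible tokenizer; tokenExists() only checks that the
+            // id is in range, not that it maps to the same text.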
+
+            assert(res2.steps.size() == 1);
+
+            // Step 1: Compare logits
+            float dist = ac::llama::LogitComparer::cosineDistance(step.data, res2.steps[0].data);
+
+            // Step 2: Calculate confidence weight
+            float entropy = normalizedEntropy(step.data);
+            float weight = 1.0f - entropy; // high confidence = high weight
+
+            // Step 3: Accumulate weighted distance
+            totalWeightedDist += weight * dist;
+            totalWeight += weight;
         }
+
+        // Final step: normalize to a confidence-weighted mean:
+        //   finalScore = sum(weight_i * dist_i) / sum(weight_i)
+
+        // Score range      | Interpretation
+        // 0.0              | Perfect match (identical predictions)
+        // 0.0001 - 0.001   | Practically indistinguishable
+        // 0.001 - 0.01     | Very close, slight variation
+        // 0.01 - 0.1       | Moderate variation, likely different versions/settings
+        // 0.1 - 1.0        | Large differences, likely different models
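+        // NB: the table assumes LogitComparer::cosineDistance returns values in [0, 1],
+        // consistent with the fakeDist = 1.0f "maximum possible distance" penalty above.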
+        float finalScore = (totalWeight > 0.0f) ? (totalWeightedDist / totalWeight) : 0.0f;
+        std::cout << "Final weighted distance score: " << finalScore << "\n";
+
     }
     std::cout << '\n';
 