2121#include " ac-test-data-llama-dir.h"
2222
2323#include < iostream>
24+ #include < fstream>
2425#include < string>
2526
2627struct GenerationStepData {
@@ -153,6 +154,130 @@ float normalizedEntropy(const ac::llama::TokenDataVector& data) {
153154}
154155
155156
157+ std::vector<Model::GenerationResult> modelGeneration (Model& m1, Model& m2, const std::string& prompt, uint32_t maxTokens) {
158+ auto res = m1.generate (prompt, maxTokens);
159+
160+ auto genPromptTokens = m2.tokenize (res.initalPrompt );
161+
162+ Model::GenerationResult res2;
163+ for (size_t i = 0 ; i < res.steps .size (); i++) {
164+ auto & step = res.steps [i];
165+ if (i > 0 ) {
166+ if (m2.tokenExists (step.token )) {
167+ genPromptTokens.push_back (step.token );
168+ }
169+ else {
170+ std::cout << " Token not found in model 2: " << step.tokenStr << " \n " ;
171+ throw std::runtime_error (" Token not found in model 2" );
172+ }
173+ }
174+
175+ if (i == 0 ) {
176+ res2 = m2.generate (genPromptTokens, 0 );
177+ } else {
178+ Model::GenerationResult tempRes;
179+ std::vector<ac::llama::Token> token{step.token };
180+ tempRes = m2.generate (token, 0 );
181+ res2.steps .push_back (tempRes.steps [0 ]);
182+ }
183+ }
184+
185+ res2.result = res.result ;
186+
187+ return {res, res2};
188+ }
189+
190+ // function to serialize the generation result in a file, so I can read it later
191+ void serialize (std::string& gguf, Model::GenerationResult& res) {
192+ std::string filename = " gen-res_" + gguf + " _" + res.initalPrompt + " .bin" ;
193+ std::ofstream f (filename, std::ios::binary);
194+ if (!f) {
195+ std::cerr << " Error opening file for writing: " << filename << " \n " ;
196+ return ;
197+ }
198+
199+ size_t ggufSize = gguf.size ();
200+ f.write (reinterpret_cast <const char *>(&ggufSize), sizeof (ggufSize));
201+ f.write (gguf.c_str (), gguf.size ());
202+
203+ size_t initialPromptSize = res.initalPrompt .size ();
204+ f.write (reinterpret_cast <const char *>(&initialPromptSize), sizeof (initialPromptSize));
205+ f.write (res.initalPrompt .c_str (), res.initalPrompt .size ());
206+
207+ size_t resultSize = res.result .size ();
208+ f.write (reinterpret_cast <const char *>(&resultSize), sizeof (resultSize));
209+ f.write (res.result .c_str (), res.result .size ());
210+
211+ size_t stepsCount = res.steps .size ();
212+ f.write (reinterpret_cast <const char *>(&stepsCount), sizeof (stepsCount));
213+ for (const auto & step : res.steps ) {
214+ size_t tokenStrSize = step.tokenStr .size ();
215+ f.write (reinterpret_cast <const char *>(&tokenStrSize), sizeof (tokenStrSize));
216+
217+ f.write (step.tokenStr .c_str (), step.tokenStr .size ());
218+ f.write (reinterpret_cast <const char *>(&step.token ), sizeof (step.token ));
219+
220+ size_t tokenCount = step.data .size ();
221+ f.write (reinterpret_cast <const char *>(&tokenCount), sizeof (tokenCount));
222+ f.write (reinterpret_cast <const char *>(step.data .data ()), sizeof (ac::llama::TokenData) * tokenCount);
223+ }
224+ }
225+
226+ Model::GenerationResult deserialize (std::string filename) {
227+ std::ifstream f (filename, std::ios::binary);
228+ if (!f) {
229+ std::cerr << " Error opening file for reading: " << filename << " \n " ;
230+ return {};
231+ }
232+
233+ Model::GenerationResult res;
234+
235+ size_t ggufSize = 0 ;
236+ f.read (reinterpret_cast <char *>(&ggufSize), sizeof (ggufSize));
237+
238+ std::string gguf;
239+ gguf.resize (ggufSize);
240+ f.read (gguf.data (), ggufSize);
241+
242+ size_t initialPromptSize = 0 ;
243+ f.read (reinterpret_cast <char *>(&initialPromptSize), sizeof (initialPromptSize));
244+
245+ res.initalPrompt .resize (initialPromptSize);
246+ f.read (res.initalPrompt .data (), initialPromptSize);
247+
248+ size_t resultSize;
249+ f.read (reinterpret_cast <char *>(&resultSize), sizeof (resultSize));
250+
251+ res.result .resize (resultSize);
252+ f.read (res.result .data (), resultSize);
253+
254+ size_t stepsCount;
255+ f.read (reinterpret_cast <char *>(&stepsCount), sizeof (stepsCount));
256+ res.steps .reserve (stepsCount);
257+ for (size_t i = 0 ; i < stepsCount; ++i) {
258+ GenerationStepData step;
259+
260+ size_t tokenStrSize;
261+ f.read (reinterpret_cast <char *>(&tokenStrSize), sizeof (tokenStrSize));
262+
263+ step.tokenStr .resize (tokenStrSize);
264+ f.read (step.tokenStr .data (), tokenStrSize);
265+
266+ f.read (reinterpret_cast <char *>(&step.token ), sizeof (step.token ));
267+
268+ size_t tokenCount;
269+ f.read (reinterpret_cast <char *>(&tokenCount), sizeof (tokenCount));
270+
271+ step.data .resize (tokenCount);
272+ f.read (reinterpret_cast <char *>(step.data .data ()), sizeof (ac::llama::TokenData) * tokenCount);
273+
274+ res.steps .push_back (step);
275+ }
276+
277+ return res;
278+ }
279+
280+
156281
157282
158283int main () try {
@@ -173,103 +298,141 @@ int main() try {
173298 std::string modelGguf2 = " Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf" ;
174299 // std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
175300
301+ std::string prompt = " The first person to" ;
302+ std::cout << " Prompt: " << prompt << " \n " ;
303+
304+ #if 0
176305 Model m1(tmpFolder + modelGguf, {});
177306 Model m2(tmpFolder + modelGguf2, {});
178307
179- std::string prompt = " The first person to" ;
180- std::cout << " Prompt: " << prompt << " \n " ;
308+ auto genRes = modelGeneration(m1, m2, prompt, 100);
309+ auto& r1 = genRes[0];
310+ auto& r2 = genRes[1];
311+ serialize(modelGguf, r1);
312+ serialize(modelGguf2, r2);
313+ #else
314+ std::string res1fn = " gen-res_" + modelGguf + " _" + prompt + " .bin" ;
315+ std::string res2fn = " gen-res_" + modelGguf2 + " _" + prompt + " .bin" ;
316+
317+ auto r1 = deserialize (res1fn);
318+ auto r2 = deserialize (res2fn);
319+ #endif
181320
182321 std::string result = prompt;
183322
184323 std::cout << " Models to compare:\n " << modelGguf << " \n " << modelGguf2 << " \n " ;
185324 std::cout << " Comparing...\n " ;
186325
187326 std::vector<float > jsdResults;
327+ std::vector<float > similarityResults;
188328 for (int i = 0 ; i < 1 ; ++i) {
189-
190- auto res = m1.generate (prompt, 100 );
191- std::cout << " Model 1 generated: " << res.result << " \n " ;
192- std::string genPrompt = res.initalPrompt ;
193-
194- auto genPromptTokens = m2.tokenize (genPrompt);
195-
196329 float totalWeightedDist = 0 .0f ;
197330 float totalWeight = 0 .0f ;
198331
199- for (size_t i = 0 ; i < res.steps .size (); i++) {
200- auto & step = res.steps [i];
201- if (i > 0 ) {
202- if (m2.tokenExists (step.token )) {
203- genPromptTokens.push_back (step.token );
204- }
205- else {
206- // Instead of skipping, penalize fully
207- float fakeDist = 1 .0f ; // Maximum possible distance
208- float weight = 1 .0f ; // Assume maximum confidence since we can't know entropy
209- totalWeightedDist += weight * fakeDist;
210- totalWeight += weight;
211-
212- jsdResults.push_back (1 );
213-
214- std::cout << " Token not found in model 2: " << step.tokenStr << " \n " ;
215- continue ;
216- }
217- }
218332
219- Model::GenerationResult res2;
220- if (i == 0 ) {
221- res2 = m2.generate (genPromptTokens, 0 );
222- } else {
223- std::vector<ac::llama::Token> token{step.token };
224- res2 = m2.generate (token, 0 );
225- }
226-
227- assert (res2.steps .size () == 1 );
333+ // auto r1 = m1.generate(prompt, 100);
334+ // std::cout << "Model 1 generated: " << r1.result << "\n";
335+ // std::string genPrompt = r1.initalPrompt;
336+ // auto genPromptTokens = m2.tokenize(genPrompt);
337+
338+ // Model::GenerationResult r2;
339+ // for (size_t i = 0; i < r1.steps.size(); i++) {
340+ // auto& step = r1.steps[i];
341+ // if (i > 0) {
342+ // if (m2.tokenExists(step.token)) {
343+ // genPromptTokens.push_back(step.token);
344+ // }
345+ // else {
346+ // // Instead of skipping, penalize fully
347+ // float fakeDist = 1.0f; // Maximum possible distance
348+ // float weight = 1.0f; // Assume maximum confidence since we can't know entropy
349+ // totalWeightedDist += weight * fakeDist;
350+ // totalWeight += weight;
351+
352+ // jsdResults.push_back(1);
353+
354+ // similarityResults.push_back(0.0f);
355+
356+ // std::cout << "Token not found in model 2: " << step.tokenStr << "\n";
357+ // continue;
358+ // }
359+ // }
360+
361+ // if (i == 0) {
362+ // r2 = m2.generate(genPromptTokens, 0);
363+ // } else {
364+ // std::vector<ac::llama::Token> token{step.token};
365+ // Model::GenerationResult res2 = m2.generate(token, 0);
366+ // assert(res2.steps.size() == 1);
367+ // r2.steps.push_back(res2.steps[0]);
368+ // }
369+ // }
370+
371+ for (size_t i = 0 ; i < r1.steps .size (); i++) {
372+ auto & step1 = r1.steps [i];
373+ auto & step2 = r2.steps [i];
374+
375+ // Calculate distance
376+ float dist = ac::llama::LogitComparer::cosineDistance (step1.data , step2.data );
377+
378+ // Calculate weight based on normalized entropy
379+ float weight = normalizedEntropy (step1.data );
380+ totalWeightedDist += weight * dist;
381+ totalWeight += weight;
382+
383+ // Calculate JSD
384+ float jsd = ac::llama::LogitComparer::JSD (step1.data , step2.data );
385+ jsdResults.push_back (jsd);
386+
387+ // Calculate similarity
388+ float similarity = ac::llama::LogitComparer::logitSimilarity (step1.data , step2.data );
389+ similarityResults.push_back (similarity);
390+
391+ std::cout << " Token: " << step1.tokenStr
392+ << " , Weight: " << weight
393+ << " , JSD: " << jsd
394+ << " , Similarity: " << similarity
395+ << " , Distance: " << dist
396+ << " \n " ;
397+ }
228398
229- {
230- // Step 1: Compare logits
231- float dist = ac::llama::LogitComparer::cosineDistance (step.data , res2.steps [0 ].data );
232399
233- // Step 2: Calculate confidence weight
234- float entropy = normalizedEntropy (step.data );
235- float weight = 1 .0f - entropy; // high confidence = high weight
400+ {
401+ // Final step: Normalize
236402
237- // Step 3: Accumulate weighted distance
238- totalWeightedDist += weight * dist;
239- totalWeight += weight;
240- }
403+ // Score range | Interpretation
404+ // 0.0 | Perfect match (identical predictions)
405+ // 0.0001 - 0.001 | Practically indistinguishable
406+ // 0.001 - 0.01 | Very close, slight variation
407+ // 0.01 - 0.1 | Moderate variation, likely different versions/settings
408+ // 0.1 - 1.0 | Large differences, likely different models
409+ float finalScore = (totalWeight > 0 .0f ) ? (totalWeightedDist / totalWeight) : 0 .0f ;
410+ std::cout << " Final weighted distance score: " << finalScore << " \n " ;
411+ }
241412
242- {
243- float jsd = ac::llama::LogitComparer::JSD (step.data , res2.steps [0 ].data );
244- jsdResults.push_back (jsd);
413+ {
414+ // Final score interpretation
415+ // average JSD score
416+ // 0.0 | Perfect match (identical predictions)
417+ // 0.0001 - 0.001 | Practically indistinguishable
418+ // 0.001 - 0.01 | Moderate variation, likely different versions/settings
419+ // 0.01 - 0.1 | Large differences, likely different models
420+ float jsdSum = 0 .0f ;
421+ for (const auto & jsd : jsdResults) {
422+ jsdSum += jsd;
245423 }
246-
424+ float jsdAvg = jsdSum / jsdResults.size ();
425+ std::cout << " Average JSD score: " << jsdAvg << " \n " ;
247426 }
248427
249- // Final step: Normalize
250-
251- // Score range | Interpretation
252- // 0.0 | Perfect match (identical predictions)
253- // 0.0001 - 0.001 | Practically indistinguishable
254- // 0.001 - 0.01 | Very close, slight variation
255- // 0.01 - 0.1 | Moderate variation, likely different versions/settings
256- // 0.1 - 1.0 | Large differences, likely different models
257- float finalScore = (totalWeight > 0 .0f ) ? (totalWeightedDist / totalWeight) : 0 .0f ;
258- std::cout << " Final weighted distance score: " << finalScore << " \n " ;
259-
260- // Final score interpretation
261- // average JSD score
262- // 0.0 | Perfect match (identical predictions)
263- // 0.0001 - 0.001 | Practically indistinguishable
264- // 0.001 - 0.01 | Moderate variation, likely different versions/settings
265- // 0.01 - 0.1 | Large differences, likely different models
266- float jsdSum = 0 .0f ;
267- for (const auto & jsd : jsdResults) {
268- jsdSum += jsd;
428+ {
429+ float similaritySum = 0 .0f ;
430+ for (const auto & similarity : similarityResults) {
431+ similaritySum += similarity;
432+ }
433+ float similarityAvg = similaritySum / similarityResults.size ();
434+ std::cout << " Average similarity score: " << similarityAvg << " \n " ;
269435 }
270- float jsdAvg = jsdSum / jsdResults.size ();
271- std::cout << " Average JSD score: " << jsdAvg << " \n " ;
272-
273436 }
274437 std::cout << ' \n ' ;
275438
0 commit comments