Skip to content

Commit 6228b94

Browse files
committed
feat: serialize and deserialize all inference info to save time on reruns
1 parent d61e0cd commit 6228b94

File tree

1 file changed

+237
-74
lines changed

1 file changed

+237
-74
lines changed

example/e-verify.cpp

Lines changed: 237 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "ac-test-data-llama-dir.h"
2222

2323
#include <iostream>
24+
#include <fstream>
2425
#include <string>
2526

2627
struct GenerationStepData {
@@ -153,6 +154,130 @@ float normalizedEntropy(const ac::llama::TokenDataVector& data) {
153154
}
154155

155156

157+
// Generate with m1, then replay the exact same token sequence through m2 so
// both models produce per-step logit data for identical contexts.
// Throws std::runtime_error if m1 emitted a token m2's vocabulary lacks.
std::vector<Model::GenerationResult> modelGeneration(Model& m1, Model& m2, const std::string& prompt, uint32_t maxTokens) {
    auto primary = m1.generate(prompt, maxTokens);

    auto replayTokens = m2.tokenize(primary.initalPrompt);

    Model::GenerationResult replay;
    for (size_t idx = 0; idx < primary.steps.size(); ++idx) {
        auto& step = primary.steps[idx];

        // From the second step on, feed m1's chosen token into m2's context.
        if (idx > 0) {
            if (!m2.tokenExists(step.token)) {
                std::cout << "Token not found in model 2: " << step.tokenStr << "\n";
                throw std::runtime_error("Token not found in model 2");
            }
            replayTokens.push_back(step.token);
        }

        if (idx == 0) {
            // First call processes the whole initial prompt at once.
            replay = m2.generate(replayTokens, 0);
        } else {
            // Subsequent calls push exactly one token and keep its step data.
            std::vector<ac::llama::Token> single{step.token};
            auto stepRes = m2.generate(single, 0);
            replay.steps.push_back(stepRes.steps[0]);
        }
    }

    replay.result = primary.result;

    return {primary, replay};
}
189+
190+
// function to serialize the generation result in a file, so I can read it later
191+
void serialize(std::string& gguf, Model::GenerationResult& res) {
192+
std::string filename = "gen-res_" + gguf + "_" + res.initalPrompt + ".bin";
193+
std::ofstream f(filename, std::ios::binary);
194+
if (!f) {
195+
std::cerr << "Error opening file for writing: " << filename << "\n";
196+
return;
197+
}
198+
199+
size_t ggufSize = gguf.size();
200+
f.write(reinterpret_cast<const char*>(&ggufSize), sizeof(ggufSize));
201+
f.write(gguf.c_str(), gguf.size());
202+
203+
size_t initialPromptSize = res.initalPrompt.size();
204+
f.write(reinterpret_cast<const char*>(&initialPromptSize), sizeof(initialPromptSize));
205+
f.write(res.initalPrompt.c_str(), res.initalPrompt.size());
206+
207+
size_t resultSize = res.result.size();
208+
f.write(reinterpret_cast<const char*>(&resultSize), sizeof(resultSize));
209+
f.write(res.result.c_str(), res.result.size());
210+
211+
size_t stepsCount = res.steps.size();
212+
f.write(reinterpret_cast<const char*>(&stepsCount), sizeof(stepsCount));
213+
for (const auto& step : res.steps) {
214+
size_t tokenStrSize = step.tokenStr.size();
215+
f.write(reinterpret_cast<const char*>(&tokenStrSize), sizeof(tokenStrSize));
216+
217+
f.write(step.tokenStr.c_str(), step.tokenStr.size());
218+
f.write(reinterpret_cast<const char*>(&step.token), sizeof(step.token));
219+
220+
size_t tokenCount = step.data.size();
221+
f.write(reinterpret_cast<const char*>(&tokenCount), sizeof(tokenCount));
222+
f.write(reinterpret_cast<const char*>(step.data.data()), sizeof(ac::llama::TokenData) * tokenCount);
223+
}
224+
}
225+
226+
// Load a generation result previously written by serialize().
// Returns an empty result if the file cannot be opened or is truncated/corrupt
// (the original version used uninitialized sizes after a failed read, which
// could trigger a garbage-sized resize and undefined behavior).
Model::GenerationResult deserialize(std::string filename) {
    std::ifstream f(filename, std::ios::binary);
    if (!f) {
        std::cerr << "Error opening file for reading: " << filename << "\n";
        return {};
    }

    auto readSize = [&f]() {
        size_t v = 0;  // zero-init so a failed read cannot yield a garbage size
        f.read(reinterpret_cast<char*>(&v), sizeof(v));
        return v;
    };
    auto readString = [&]() {
        std::string s;
        s.resize(readSize());
        f.read(s.data(), static_cast<std::streamsize>(s.size()));
        return s;
    };

    Model::GenerationResult res;

    // The gguf name is stored first; it is read only to advance past it.
    std::string gguf = readString();
    (void)gguf;

    res.initalPrompt = readString();
    res.result = readString();

    size_t stepsCount = readSize();
    if (!f) {
        std::cerr << "Error: truncated or corrupt file: " << filename << "\n";
        return {};
    }

    res.steps.reserve(stepsCount);
    for (size_t i = 0; i < stepsCount; ++i) {
        GenerationStepData step;

        step.tokenStr = readString();
        f.read(reinterpret_cast<char*>(&step.token), sizeof(step.token));

        size_t tokenCount = readSize();
        if (!f) {
            // Bail out before resize(): tokenCount is untrustworthy here.
            std::cerr << "Error: truncated or corrupt file: " << filename << "\n";
            return {};
        }
        step.data.resize(tokenCount);
        f.read(reinterpret_cast<char*>(step.data.data()),
               static_cast<std::streamsize>(sizeof(ac::llama::TokenData) * tokenCount));

        res.steps.push_back(std::move(step));  // move: step owns heap buffers
    }

    return res;
}
279+
280+
156281

157282

158283
int main() try {
@@ -173,103 +298,141 @@ int main() try {
173298
std::string modelGguf2 = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
174299
// std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
175300

301+
std::string prompt = "The first person to";
302+
std::cout << "Prompt: " << prompt << "\n";
303+
304+
#if 0
176305
Model m1(tmpFolder + modelGguf, {});
177306
Model m2(tmpFolder + modelGguf2, {});
178307

179-
std::string prompt = "The first person to";
180-
std::cout << "Prompt: " << prompt << "\n";
308+
auto genRes = modelGeneration(m1, m2, prompt, 100);
309+
auto& r1 = genRes[0];
310+
auto& r2 = genRes[1];
311+
serialize(modelGguf, r1);
312+
serialize(modelGguf2, r2);
313+
#else
314+
std::string res1fn = "gen-res_" + modelGguf + "_" + prompt + ".bin";
315+
std::string res2fn = "gen-res_" + modelGguf2 + "_" + prompt + ".bin";
316+
317+
auto r1 = deserialize(res1fn);
318+
auto r2 = deserialize(res2fn);
319+
#endif
181320

182321
std::string result = prompt;
183322

184323
std::cout << "Models to compare:\n" << modelGguf << "\n" << modelGguf2 << "\n";
185324
std::cout << "Comparing...\n";
186325

187326
std::vector<float> jsdResults;
327+
std::vector<float> similarityResults;
188328
for (int i = 0; i < 1; ++i) {
189-
190-
auto res = m1.generate(prompt, 100);
191-
std::cout << "Model 1 generated: " << res.result << "\n";
192-
std::string genPrompt = res.initalPrompt;
193-
194-
auto genPromptTokens = m2.tokenize(genPrompt);
195-
196329
float totalWeightedDist = 0.0f;
197330
float totalWeight = 0.0f;
198331

199-
for (size_t i = 0; i < res.steps.size(); i++) {
200-
auto& step = res.steps[i];
201-
if (i > 0) {
202-
if (m2.tokenExists(step.token)) {
203-
genPromptTokens.push_back(step.token);
204-
}
205-
else {
206-
// Instead of skipping, penalize fully
207-
float fakeDist = 1.0f; // Maximum possible distance
208-
float weight = 1.0f; // Assume maximum confidence since we can't know entropy
209-
totalWeightedDist += weight * fakeDist;
210-
totalWeight += weight;
211-
212-
jsdResults.push_back(1);
213-
214-
std::cout << "Token not found in model 2: " << step.tokenStr << "\n";
215-
continue;
216-
}
217-
}
218332

219-
Model::GenerationResult res2;
220-
if (i == 0) {
221-
res2 = m2.generate(genPromptTokens, 0);
222-
} else {
223-
std::vector<ac::llama::Token> token{step.token};
224-
res2 = m2.generate(token, 0);
225-
}
226-
227-
assert(res2.steps.size() == 1);
333+
// auto r1 = m1.generate(prompt, 100);
334+
// std::cout << "Model 1 generated: " << r1.result << "\n";
335+
// std::string genPrompt = r1.initalPrompt;
336+
// auto genPromptTokens = m2.tokenize(genPrompt);
337+
338+
// Model::GenerationResult r2;
339+
// for (size_t i = 0; i < r1.steps.size(); i++) {
340+
// auto& step = r1.steps[i];
341+
// if (i > 0) {
342+
// if (m2.tokenExists(step.token)) {
343+
// genPromptTokens.push_back(step.token);
344+
// }
345+
// else {
346+
// // Instead of skipping, penalize fully
347+
// float fakeDist = 1.0f; // Maximum possible distance
348+
// float weight = 1.0f; // Assume maximum confidence since we can't know entropy
349+
// totalWeightedDist += weight * fakeDist;
350+
// totalWeight += weight;
351+
352+
// jsdResults.push_back(1);
353+
354+
// similarityResults.push_back(0.0f);
355+
356+
// std::cout << "Token not found in model 2: " << step.tokenStr << "\n";
357+
// continue;
358+
// }
359+
// }
360+
361+
// if (i == 0) {
362+
// r2 = m2.generate(genPromptTokens, 0);
363+
// } else {
364+
// std::vector<ac::llama::Token> token{step.token};
365+
// Model::GenerationResult res2 = m2.generate(token, 0);
366+
// assert(res2.steps.size() == 1);
367+
// r2.steps.push_back(res2.steps[0]);
368+
// }
369+
// }
370+
371+
for (size_t i = 0; i < r1.steps.size(); i++) {
372+
auto& step1 = r1.steps[i];
373+
auto& step2 = r2.steps[i];
374+
375+
// Calculate distance
376+
float dist = ac::llama::LogitComparer::cosineDistance(step1.data, step2.data);
377+
378+
// Calculate weight based on normalized entropy
379+
float weight = normalizedEntropy(step1.data);
380+
totalWeightedDist += weight * dist;
381+
totalWeight += weight;
382+
383+
// Calculate JSD
384+
float jsd = ac::llama::LogitComparer::JSD(step1.data, step2.data);
385+
jsdResults.push_back(jsd);
386+
387+
// Calculate similarity
388+
float similarity = ac::llama::LogitComparer::logitSimilarity(step1.data, step2.data);
389+
similarityResults.push_back(similarity);
390+
391+
std::cout << "Token: " << step1.tokenStr
392+
<< ", Weight: " << weight
393+
<< ", JSD: " << jsd
394+
<< ", Similarity: " << similarity
395+
<< ", Distance: " << dist
396+
<< "\n";
397+
}
228398

229-
{
230-
// Step 1: Compare logits
231-
float dist = ac::llama::LogitComparer::cosineDistance(step.data, res2.steps[0].data);
232399

233-
// Step 2: Calculate confidence weight
234-
float entropy = normalizedEntropy(step.data);
235-
float weight = 1.0f - entropy; // high confidence = high weight
400+
{
401+
// Final step: Normalize
236402

237-
// Step 3: Accumulate weighted distance
238-
totalWeightedDist += weight * dist;
239-
totalWeight += weight;
240-
}
403+
// Score range | Interpretation
404+
// 0.0 | Perfect match (identical predictions)
405+
// 0.0001 - 0.001 | Practically indistinguishable
406+
// 0.001 - 0.01 | Very close, slight variation
407+
// 0.01 - 0.1 | Moderate variation, likely different versions/settings
408+
// 0.1 - 1.0 | Large differences, likely different models
409+
float finalScore = (totalWeight > 0.0f) ? (totalWeightedDist / totalWeight) : 0.0f;
410+
std::cout << "Final weighted distance score: " << finalScore << "\n";
411+
}
241412

242-
{
243-
float jsd = ac::llama::LogitComparer::JSD(step.data, res2.steps[0].data);
244-
jsdResults.push_back(jsd);
413+
{
414+
// Final score interpretation
415+
// average JSD score
416+
// 0.0 | Perfect match (identical predictions)
417+
// 0.0001 - 0.001 | Practically indistinguishable
418+
// 0.001 - 0.01 | Moderate variation, likely different versions/settings
419+
// 0.01 - 0.1 | Large differences, likely different models
420+
float jsdSum = 0.0f;
421+
for (const auto& jsd : jsdResults) {
422+
jsdSum += jsd;
245423
}
246-
424+
float jsdAvg = jsdSum / jsdResults.size();
425+
std::cout << "Average JSD score: " << jsdAvg << "\n";
247426
}
248427

249-
// Final step: Normalize
250-
251-
// Score range | Interpretation
252-
// 0.0 | Perfect match (identical predictions)
253-
// 0.0001 - 0.001 | Practically indistinguishable
254-
// 0.001 - 0.01 | Very close, slight variation
255-
// 0.01 - 0.1 | Moderate variation, likely different versions/settings
256-
// 0.1 - 1.0 | Large differences, likely different models
257-
float finalScore = (totalWeight > 0.0f) ? (totalWeightedDist / totalWeight) : 0.0f;
258-
std::cout << "Final weighted distance score: " << finalScore << "\n";
259-
260-
// Final score interpretation
261-
// average JSD score
262-
// 0.0 | Perfect match (identical predictions)
263-
// 0.0001 - 0.001 | Practically indistinguishable
264-
// 0.001 - 0.01 | Moderate variation, likely different versions/settings
265-
// 0.01 - 0.1 | Large differences, likely different models
266-
float jsdSum = 0.0f;
267-
for (const auto& jsd : jsdResults) {
268-
jsdSum += jsd;
428+
{
429+
float similaritySum = 0.0f;
430+
for (const auto& similarity : similarityResults) {
431+
similaritySum += similarity;
432+
}
433+
float similarityAvg = similaritySum / similarityResults.size();
434+
std::cout << "Average similarity score: " << similarityAvg << "\n";
269435
}
270-
float jsdAvg = jsdSum / jsdResults.size();
271-
std::cout << "Average JSD score: " << jsdAvg << "\n";
272-
273436
}
274437
std::cout << '\n';
275438

0 commit comments

Comments
 (0)