Skip to content

Commit 0989c5e

Browse files
committed
refactor: auto-detect an existing inference result for the same model & prompt
1 parent 2721685 commit 0989c5e

File tree

1 file changed

+31
-27
lines changed

1 file changed

+31
-27
lines changed

example/e-verify.cpp

Lines changed: 31 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
#include <iostream>
2424
#include <fstream>
2525
#include <string>
26+
#include <filesystem>
2627

2728
struct GenerationStepData {
2829
std::string tokenStr;
@@ -54,7 +55,7 @@ class Model {
5455
std::vector<GenerationStepData> steps;
5556
};
5657

57-
GenerationResult generate(std::string prompt, uint32_t maxTokens) {
58+
GenerationResult generate(std::string_view prompt, uint32_t maxTokens) {
5859
auto promptTokens = m_model->vocab().tokenize(prompt, true, true);
5960
return generate_impl(promptTokens, maxTokens);
6061
}
@@ -154,7 +155,7 @@ float normalizedEntropy(const ac::llama::TokenDataVector& data) {
154155
}
155156

156157

157-
std::vector<Model::GenerationResult> modelGeneration(Model& m1, Model& m2, const std::string& prompt, uint32_t maxTokens) {
158+
std::vector<Model::GenerationResult> modelGeneration(Model& m1, Model& m2, std::string_view prompt, uint32_t maxTokens) {
158159
auto res = m1.generate(prompt, maxTokens);
159160

160161
auto genPromptTokens = m2.tokenize(res.initalPrompt);
@@ -188,8 +189,7 @@ std::vector<Model::GenerationResult> modelGeneration(Model& m1, Model& m2, const
188189
}
189190

190191
// function to serialize the generation result in a file, so I can read it later
191-
void serialize(std::string& gguf, Model::GenerationResult& res) {
192-
std::string filename = "gen-res_" + gguf + "_" + res.initalPrompt + ".bin";
192+
void serialize(std::string_view filename, std::string_view gguf, Model::GenerationResult& res) {
193193
std::ofstream f(filename, std::ios::binary);
194194
if (!f) {
195195
std::cerr << "Error opening file for writing: " << filename << "\n";
@@ -198,7 +198,7 @@ void serialize(std::string& gguf, Model::GenerationResult& res) {
198198

199199
size_t ggufSize = gguf.size();
200200
f.write(reinterpret_cast<const char*>(&ggufSize), sizeof(ggufSize));
201-
f.write(gguf.c_str(), gguf.size());
201+
f.write(gguf.data(), gguf.size());
202202

203203
size_t initialPromptSize = res.initalPrompt.size();
204204
f.write(reinterpret_cast<const char*>(&initialPromptSize), sizeof(initialPromptSize));
@@ -223,7 +223,7 @@ void serialize(std::string& gguf, Model::GenerationResult& res) {
223223
}
224224
}
225225

226-
Model::GenerationResult deserialize(std::string filename) {
226+
Model::GenerationResult deserialize(std::string_view filename) {
227227
std::ifstream f(filename, std::ios::binary);
228228
if (!f) {
229229
std::cerr << "Error opening file for reading: " << filename << "\n";
@@ -277,7 +277,7 @@ Model::GenerationResult deserialize(std::string filename) {
277277
return res;
278278
}
279279

280-
void runTest(Model::GenerationResult& r1, Model::GenerationResult& r2) {
280+
void runCompare(Model::GenerationResult& r1, Model::GenerationResult& r2) {
281281
std::vector<float> jsdResults;
282282
std::vector<float> similarityResults;
283283
float totalWeightedDist = 0.0f;
@@ -362,39 +362,43 @@ int main() try {
362362

363363
// load model
364364
std::string tmpFolder = AC_TEST_DATA_LLAMA_DIR "/../../../tmp/";
365-
// std::string modelGguf = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
366-
std::string modelGguf = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
365+
std::string modelGguf = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
366+
// std::string modelGguf = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
367367
// std::string modelGguf = "BgGPT-Gemma-2-2B-IT-v1.0.Q8_0.gguf";
368368
// std::string modelGguf = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
369-
std::string modelGguf2 = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
370-
// std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
369+
// std::string modelGguf2 = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
370+
std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
371371

372-
std::string prompt = "The first person to";
372+
// std::string prompt = "The first person to ";
373+
std::string prompt = "Explain quantum physics in simple terms.";
373374
std::cout << "Prompt: " << prompt << "\n";
374375

375-
#if 0
376-
Model m1(tmpFolder + modelGguf, {});
377-
Model m2(tmpFolder + modelGguf2, {});
378-
379-
auto genRes = modelGeneration(m1, m2, prompt, 100);
380-
auto& r1 = genRes[0];
381-
auto& r2 = genRes[1];
382-
serialize(modelGguf, r1);
383-
serialize(modelGguf2, r2);
384-
#else
385376
std::string res1fn = "gen-res_" + modelGguf + "_" + prompt + ".bin";
386377
std::string res2fn = "gen-res_" + modelGguf2 + "_" + prompt + ".bin";
387-
388-
auto r1 = deserialize(res1fn);
389-
auto r2 = deserialize(res2fn);
390-
#endif
378+
bool shouldRunGenerate = !(std::filesystem::exists(res1fn) && std::filesystem::exists(res2fn));
379+
380+
Model::GenerationResult r1;
381+
Model::GenerationResult r2;
382+
if (shouldRunGenerate) {
383+
Model m1(tmpFolder + modelGguf, {});
384+
Model m2(tmpFolder + modelGguf2, {});
385+
386+
auto genRes = modelGeneration(m1, m2, prompt, 100);
387+
r1 = std::move(genRes[0]);
388+
r2 = std::move(genRes[1]);
389+
serialize(res1fn, modelGguf, r1);
390+
serialize(res2fn, modelGguf2, r2);
391+
} else {
392+
r1 = deserialize(res1fn);
393+
r2 = deserialize(res2fn);
394+
}
391395

392396
std::string result = prompt;
393397

394398
std::cout << "Models to compare:\n" << modelGguf << "\n" << modelGguf2 << "\n";
395399
std::cout << "Comparing...\n";
396400

397-
runTest(r1, r2);
401+
runCompare(r1, r2);
398402

399403
return 0;
400404
}

0 commit comments

Comments (0)