2323#include < iostream>
2424#include < fstream>
2525#include < string>
26+ #include < filesystem>
2627
2728struct GenerationStepData {
// NOTE(review): diff fragment — the remaining fields of GenerationStepData
// (hidden between this hunk and the next) are not visible here.
2829 std::string tokenStr;
@@ -54,7 +55,7 @@ class Model {
// Tail of Model::GenerationResult: the per-step records captured for a run.
5455 std::vector<GenerationStepData> steps;
5556 };
5657
57- GenerationResult generate (std::string prompt, uint32_t maxTokens) {
58+ GenerationResult generate (std::string_view prompt, uint32_t maxTokens) {
5859 auto promptTokens = m_model->vocab ().tokenize (prompt, true , true );
5960 return generate_impl (promptTokens, maxTokens);
6061 }
@@ -154,7 +155,7 @@ float normalizedEntropy(const ac::llama::TokenDataVector& data) {
// NOTE(review): only the closing brace of normalizedEntropy is visible in
// this fragment; its body lies outside this view.
154155}
155156
156157
// Runs m1 on `prompt`, then re-tokenizes m1's initial prompt with m2; the
// rest of the body (hidden between hunks) presumably replays the same
// generation on m2 and returns both results — TODO confirm.
157- std::vector<Model::GenerationResult> modelGeneration (Model& m1, Model& m2, const std::string& prompt, uint32_t maxTokens) {
158+ std::vector<Model::GenerationResult> modelGeneration (Model& m1, Model& m2, std::string_view prompt, uint32_t maxTokens) {
158159 auto res = m1.generate (prompt, maxTokens);
159160
// NOTE(review): `initalPrompt` (sic) is misspelled consistently across this
// file; renaming it would also touch the serialize/deserialize sites.
160161 auto genPromptTokens = m2.tokenize (res.initalPrompt );
@@ -188,8 +189,7 @@ std::vector<Model::GenerationResult> modelGeneration(Model& m1, Model& m2, const
188189}
189190
190191// function to serialize the generation result in a file, so I can read it later
// NOTE(review): writes a raw binary dump — each string is a size_t length
// followed by its bytes. Writing raw size_t makes the file format depend on
// the writer's pointer size and endianness; acceptable for a local scratch
// cache, not a portable format.
191- void serialize (std::string& gguf, Model::GenerationResult& res) {
192- std::string filename = " gen-res_" + gguf + " _" + res.initalPrompt + " .bin" ;
192+ void serialize (std::string_view filename, std::string_view gguf, Model::GenerationResult& res) {
// NOTE(review): std::ofstream has no string_view constructor; this compiles
// via implicit conversion of `filename` to std::filesystem::path — confirm
// that is intended rather than taking const std::string&.
193193 std::ofstream f (filename, std::ios::binary);
194194 if (!f) {
195195 std::cerr << " Error opening file for writing: " << filename << " \n " ;
// (an early return on open failure presumably follows in the hidden lines)
@@ -198,7 +198,7 @@ void serialize(std::string& gguf, Model::GenerationResult& res) {
198198
199199 size_t ggufSize = gguf.size ();
200200 f.write (reinterpret_cast <const char *>(&ggufSize), sizeof (ggufSize));
201- f.write (gguf.c_str (), gguf.size ());
201+ f.write (gguf.data (), gguf.size ());
202202
203203 size_t initialPromptSize = res.initalPrompt .size ();
204204 f.write (reinterpret_cast <const char *>(&initialPromptSize), sizeof (initialPromptSize));
@@ -223,7 +223,7 @@ void serialize(std::string& gguf, Model::GenerationResult& res) {
223223 }
224224}
225225
// Reads back a file produced by serialize(); the bulk of the parsing logic
// is in the hidden lines between these hunks.
226- Model::GenerationResult deserialize (std::string filename) {
226+ Model::GenerationResult deserialize (std::string_view filename) {
// NOTE(review): as in serialize(), the string_view is implicitly converted
// to std::filesystem::path for the stream constructor — TODO confirm.
227227 std::ifstream f (filename, std::ios::binary);
228228 if (!f) {
229229 std::cerr << " Error opening file for reading: " << filename << " \n " ;
// NOTE(review): the open-failure path prints to stderr; unless the hidden
// lines return early, callers receive a default-constructed result and
// cannot distinguish failure from an empty generation — consider
// std::optional<Model::GenerationResult>. TODO confirm the hidden path.
@@ -277,7 +277,7 @@ Model::GenerationResult deserialize(std::string filename) {
277277 return res;
278278}
279279
// Compares two recorded generation runs; renamed from runTest to runCompare
// in this change. The accumulator names suggest per-step Jensen–Shannon
// distances and similarity scores — presumably computed in the hidden
// comparison loop; confirm against the full function body.
280- void runTest (Model::GenerationResult& r1, Model::GenerationResult& r2) {
280+ void runCompare (Model::GenerationResult& r1, Model::GenerationResult& r2) {
281281 std::vector<float > jsdResults;
282282 std::vector<float > similarityResults;
283283 float totalWeightedDist = 0 .0f ;
@@ -362,39 +362,43 @@ int main() try {
// NOTE(review): this hunk replaces the old `#if 0` generate/deserialize
// toggle with an automatic file-existence cache check. The opening
// `int main() try {` and its `catch` live outside this fragment.
362362
363363 // load model
364364 std::string tmpFolder = AC_TEST_DATA_LLAMA_DIR " /../../../tmp/" ;
365- // std::string modelGguf = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
366- std::string modelGguf = " Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf" ;
365+ std::string modelGguf = " Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf" ;
366+ // std::string modelGguf = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
367367 // std::string modelGguf = "BgGPT-Gemma-2-2B-IT-v1.0.Q8_0.gguf";
368368 // std::string modelGguf = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
369- std::string modelGguf2 = " Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf" ;
370- // std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";
369+ // std::string modelGguf2 = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
370+ std::string modelGguf2 = " Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf" ;
371371
372- std::string prompt = " The first person to" ;
372+ // std::string prompt = "The first person to ";
373+ std::string prompt = " Explain quantum physics in simple terms." ;
373374 std::cout << " Prompt: " << prompt << " \n " ;
374375
// NOTE(review): the cache filename embeds the raw prompt; a prompt
// containing a path separator or other filesystem-reserved character would
// make file creation fail — consider hashing the prompt instead.
385376 std::string res1fn = " gen-res_" + modelGguf + " _" + prompt + " .bin" ;
386377 std::string res2fn = " gen-res_" + modelGguf2 + " _" + prompt + " .bin" ;
// Regenerate only when either cached result file is missing.
// NOTE(review): existence is the only cache key — changing maxTokens (100)
// or the model files silently reuses stale results; delete the .bin files
// to force regeneration.
378+ bool shouldRunGenerate = !(std::filesystem::exists (res1fn) && std::filesystem::exists (res2fn));
379+
380+ Model::GenerationResult r1;
381+ Model::GenerationResult r2;
382+ if (shouldRunGenerate) {
383+ Model m1 (tmpFolder + modelGguf, {});
384+ Model m2 (tmpFolder + modelGguf2, {});
385+
386+ auto genRes = modelGeneration (m1, m2, prompt, 100 );
387+ r1 = std::move (genRes[0 ]);
388+ r2 = std::move (genRes[1 ]);
389+ serialize (res1fn, modelGguf, r1);
390+ serialize (res2fn, modelGguf2, r2);
391+ } else {
392+ r1 = deserialize (res1fn);
393+ r2 = deserialize (res2fn);
394+ }
391395
// NOTE(review): `result` is never read in the visible lines — likely a
// leftover; confirm against the hidden remainder of main before removing.
392396 std::string result = prompt;
393397
394398 std::cout << " Models to compare:\n " << modelGguf << " \n " << modelGguf2 << " \n " ;
395399 std::cout << " Comparing...\n " ;
396400
397- runTest (r1, r2);
401+ runCompare (r1, r2);
398402
399403 return 0 ;
400404}
0 commit comments