
Commit a8c0665

feat: add example to verify models
1 parent cb33cd2 commit a8c0665

2 files changed: +155 -0 lines changed


example/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ endfunction()
 add_example(basic)
 add_example(embedding)
 add_example(infill)
+add_example(verify)
 
 CPMAddPackage(gh:alpaca-core/[email protected])
 if(TARGET ac-dev::imgui-sdl-app)

example/e-verify.cpp

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
// Copyright (c) Alpaca Core
// SPDX-License-Identifier: MIT
//

// trivial example of using alpaca-core's llama inference to verify two models against each other

// llama
#include <ac/llama/Init.hpp>
#include <ac/llama/Model.hpp>
#include <ac/llama/Instance.hpp>
#include <ac/llama/Session.hpp>
#include <ac/llama/ControlVector.hpp>
#include <ac/llama/ResourceCache.hpp>
#include <ac/llama/LogitComparer.hpp>

// logging
#include <ac/jalog/Instance.hpp>
#include <ac/jalog/sinks/ColorSink.hpp>

// model source directory
#include "ac-test-data-llama-dir.h"

#include <cassert>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

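// data recorded for a single generation step: the token, its text form, and the top-K sampled token data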
struct GenerationStepData {
    std::string tokenStr;
    int32_t token;
    ac::llama::TokenDataVector data;
};

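// global resource manager and llama model cache shared by all Model instances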
ac::local::ResourceManager g_rm;
ac::llama::ResourceCache g_rcache(g_rm);

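// small wrapper which loads a gguf model through the cache and generates text, collecting the sampled token data at every step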
class Model {
public:
    Model(const std::string& gguf, ac::llama::Model::Params params) {
        m_model = g_rcache.getModel({.gguf = gguf, .params = {params}});
        m_instance.reset(new ac::llama::Instance(*m_model, {
            .ctxSize = 2048,
        }));
    }

    struct GenerationResult {
        std::string initialPrompt;
        std::string result;
        std::vector<GenerationStepData> steps;
    };

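    // generate up to maxTokens tokens from the prompt; the first recorded step holds the data sampled right after the prompt itself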
    GenerationResult generate(std::string prompt, uint32_t maxTokens) {
        m_session = &m_instance->startSession({});

        auto promptTokens = m_model->vocab().tokenize(prompt, true, true);
        m_session->setInitialPrompt(promptTokens);

        constexpr int32_t topK = 10;
        auto data = m_session->getSampledTokenData(topK);

        auto token = promptTokens.back();
        auto tokenStr = m_model->vocab().tokenToString(token);

        std::vector<GenerationStepData> genSteps;
        genSteps.push_back(GenerationStepData{
            .tokenStr = tokenStr,
            .token = token,
            .data = std::move(data)
        });

        std::string result;
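        // generate new tokens until maxTokens is reached or the model yields an invalid token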
        for (size_t i = 0; i < maxTokens; i++) {
            auto token = m_session->getToken();
            if (token == ac::llama::Token_Invalid) {
                // no more tokens
                break;
            }
            tokenStr = m_model->vocab().tokenToString(token);
            result += tokenStr;

            auto data = m_session->getSampledTokenData(topK);

            genSteps.push_back({
                .tokenStr = tokenStr,
                .token = token,
                .data = std::move(data)
            });
        }

        m_instance->stopSession();
        m_session = nullptr;

        return {
            .initialPrompt = prompt,
            .result = result,
            .steps = genSteps
        };
    }

private:
    ac::llama::ResourceCache::ModelLock m_model;
    std::unique_ptr<ac::llama::Instance> m_instance;
    ac::llama::Session* m_session = nullptr;
};

int main() try {
    ac::jalog::Instance jl;
    jl.setup().add<ac::jalog::sinks::ColorSink>();

    // initialize the library
    ac::llama::initLibrary();

    std::vector<GenerationStepData> genSteps;

    // load the two models to compare
    std::string tmpFolder = AC_TEST_DATA_LLAMA_DIR "/../../../tmp/";
    std::string modelGguf = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
    std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";

    Model m1(tmpFolder + modelGguf, {});
    Model m2(tmpFolder + modelGguf2, {});

    std::string prompt = "The first person to";
    std::cout << "Prompt: " << prompt << "\n";

    std::string result = prompt;

    std::cout << "Models to compare:\n" << modelGguf << "\n" << modelGguf2 << "\n";
    std::cout << "Comparing...\n";

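    // replay model 1's generation through model 2: for every step, feed model 2 the text produced so far and compare its sampled token data with model 1's data for the same position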
    for (int i = 0; i < 1; ++i) {
        auto res = m1.generate(prompt, 100);
        std::cout << "Model 1 generated: " << res.result << "\n";
        std::string genPrompt = res.initialPrompt;
        for (size_t i = 0; i < res.steps.size(); i++) {
            auto& step = res.steps[i];
            if (i > 0) {
                genPrompt += step.tokenStr;
            }
            auto res2 = m2.generate(genPrompt, 0);
            assert(res2.steps.size() == 1);

            if (ac::llama::LogitComparer::compare(step.data, res2.steps[0].data)) {
                std::cout << "Models are the same. Generated str by now:\n" << genPrompt << "\n\n";
            }
        }
    }
    std::cout << '\n';

    return 0;
}
catch (const std::exception& e) {
    std::cerr << "Error: " << e.what() << std::endl;
    return 1;
}
