
Commit 1a4473c

chore: cleanup, ref #58
1 parent 82a98a5

1 file changed

example/e-rag.cpp

Lines changed: 45 additions & 48 deletions
@@ -26,12 +26,7 @@
 #include <iostream>
 #include <string>
 
-std::vector<std::string> g_fruits = {
-    "Apple", "Banana", "Orange", "Strawberry", "Blueberry", "Raspberry", "Blackberry", "Pineapple",
-    "Mango", "Peach", "Plum", "Cherry", "Grapes", "Watermelon", "Cantaloupe", "Honeydew", "Kiwi",
-    "Pomegranate", "Papaya", "Fig", "Dragon Fruit", "Lychee", "Coconut", "Guava", "Passion Fruit",
-    "Lemon", "Lime", "Pear", "Cranberry", "Apricot"
-};
+
 
 std::vector<std::pair<std::string, std::string>> g_recipes = {
     {"Apple Banana", "Apple Banana Smoothie: Blend 1 apple, 1 banana, 1 cup milk, and 1 tbsp honey until smooth."},
@@ -159,6 +154,9 @@ class VectorDatabase {
 
 };
 
+// This is a simple object for storing a document.
+// You can modify it to store custom information if needed.
+// Note: Do not forget to implement the write and read functions for your custom object.
 struct Document {
     std::string content;
     void write(std::ostream& out) const {
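
Side note on the comment added above: the Document type is intentionally minimal, and the comment invites custom variants. A small illustrative sketch of one follows; the TaggedDocument name, the extra tag field, and the length-prefixed serialization are assumptions for illustration, not code from this repository.

// Hypothetical example: a document variant carrying one extra field.
// The length-prefixed string encoding is an assumed serialization contract,
// not the repo's actual format.
#include <iostream>
#include <string>

struct TaggedDocument {
    std::string content;
    std::string tag; // e.g. "smoothie" or "salad"

    void write(std::ostream& out) const {
        auto writeStr = [&](const std::string& s) {
            size_t len = s.size();
            out.write(reinterpret_cast<const char*>(&len), sizeof(len));
            out.write(s.data(), len);
        };
        writeStr(content);
        writeStr(tag);
    }

    void read(std::istream& in) {
        auto readStr = [&](std::string& s) {
            size_t len = 0;
            in.read(reinterpret_cast<char*>(&len), sizeof(len));
            s.resize(len);
            in.read(&s[0], len);
        };
        readStr(content);
        readStr(tag);
    }
};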
@@ -176,39 +174,50 @@ struct Document {
     }
 };
 
+std::string retrieveKnowledge(const std::string& query, VectorDatabase<Document>& vdb) {
+    // Search for the most relevant recipes
+    // in the vector database
+    int k = 3; // Retrieve top-3 results
+    std::vector<float> searchEmbedding = g_embeddingInstance->getEmbeddingVector(g_embeddingInstance->model().vocab().tokenize(query, true, true));
+    auto results = vdb.searchKnn(searchEmbedding, k);
+
+    std::string knowledgeContent = "";
+
+    int cnt = 1;
+    std::cout << "============= database context =============\n";
+    while (!results.empty()) {
+        auto res = results.top();
+        std::cout << "\t" << cnt << "."
+                  << " Distance: " << res.dist
+                  << " ID:" << res.idx
+                  << " content: " << res.content.content
+                  << std::endl;
+        results.pop();
+        knowledgeContent += std::to_string(cnt++) + ": " + res.content.content + "\n";
+    }
+    std::cout << "=============================================\n";
+
+    return knowledgeContent;
+}
+
 std::string generateResponse(ac::llama::Session& session, const std::string& prompt, VectorDatabase<Document>& vdb, int maxTokens = 512) {
     ac::llama::ChatFormat chatFormat("llama3");
     ac::llama::ChatMsg msg{.text = prompt, .role = "user"};
 
-    {
-        // Search for the most relevant recipe
-        // in the vector database
-        int k = 3; // Retrieve top-3 results
-        std::vector<float> seachEmbedding = g_embeddingInstance->getEmbeddingVector(g_embeddingInstance->model().vocab().tokenize(prompt, true, true));
-        auto result = vdb.searchKnn(seachEmbedding, k);
-
-        std::string knowledge = "You are a recipe assistant. Given the following relevant recipes, select the most relevant one or paraphrase it:\n";
-        std::cout << "Nearest neighbors:\n";
-        int cnt = 1;
-        while (!result.empty()) {
-            auto res = result.top();
-            std::cout << "\tDistance: " << res.dist
-                << " ID:" << res.idx
-                << " content: " << res.content.content << std::endl;
-            result.pop();
-            knowledge += std::to_string(cnt++) + ": " + res.content.content + "\n";
-        }
-
-        session.pushPrompt(g_chatInstance->model().vocab().tokenize(knowledge, false, false));
-    }
+    // 1. Fill the context with the relevant recipes
+    const std::string systemPrompt = "You are a recipe assistant. Given the following relevant recipes, select the most relevant one or paraphrase it:\n";
+    const std::string knowledge = retrieveKnowledge(prompt, vdb);
+    session.pushPrompt(g_chatInstance->model().vocab().tokenize(systemPrompt + knowledge, false, false));
 
+    // 2. Add the user prompt to the context
     auto formatted = chatFormat.formatMsg(msg, g_messages, true);
     g_messages.emplace_back(msg);
-    // auto formatted = chatFormat.formatChat(g_messages, true);
+    // Note: To format the full chat and push it into the context, uncomment the following line
+    // formatted = chatFormat.formatChat(g_messages, true);
     session.pushPrompt(g_chatInstance->model().vocab().tokenize(formatted, false, false));
 
+    // 3. Generate the response
     std::string response = "";
-
     for (int i = 0; i < maxTokens; ++i) {
         auto token = session.getToken();
         if (token == ac::llama::Token_Invalid) {
@@ -243,32 +252,20 @@ int main() try {
     // initialize the library
     ac::llama::initLibrary();
 
-    // This model won't work for this example, but it's a placeholder
-    // Download better model - llama3.2 8b for example
+    // Note: This model won't work for this example, but it's a placeholder.
+    // Download a better model - llama3.2 8b, for example
     std::string modelGguf = AC_TEST_DATA_LLAMA_DIR "/gpt2-117m-q6_k.gguf";
     std::string embeddingModelGguf = AC_TEST_DATA_LLAMA_DIR "/bge-small-en-v1.5-f16.gguf";
-    auto modelLoadProgressCallback = [](float progress) {
-        const int barWidth = 50;
-        static float currProgress = 0;
-        auto delta = int(progress * barWidth) - int(currProgress * barWidth);
-        for (int i = 0; i < delta; i++) {
-            std::cout.put('=');
-        }
-        currProgress = progress;
-        if (progress == 1.f) {
-            std::cout << '\n';
-        }
-        return true;
-    };
 
     ac::llama::ResourceCache cache;
-    // create inference objects
-    auto model = cache.getModel({.gguf = modelGguf, .params = {}}, modelLoadProgressCallback);
+
+    // create objects for the inference
+    auto model = cache.getModel({.gguf = modelGguf});
     ac::llama::Instance instance(*model, {});
     g_chatInstance = &instance;
 
-    // create embedding objects
-    auto mEmbedding = cache.getModel({.gguf = embeddingModelGguf, .params = {}}, modelLoadProgressCallback);
+    // create objects for the embedding
+    auto mEmbedding = cache.getModel({.gguf = embeddingModelGguf});
     ac::llama::InstanceEmbedding iEmbedding(*mEmbedding, {});
     g_embeddingInstance = &iEmbedding;
 