 #include <iostream>
 #include <string>
 
-std::vector<std::string> g_fruits = {
-    "Apple", "Banana", "Orange", "Strawberry", "Blueberry", "Raspberry", "Blackberry", "Pineapple",
-    "Mango", "Peach", "Plum", "Cherry", "Grapes", "Watermelon", "Cantaloupe", "Honeydew", "Kiwi",
-    "Pomegranate", "Papaya", "Fig", "Dragon Fruit", "Lychee", "Coconut", "Guava", "Passion Fruit",
-    "Lemon", "Lime", "Pear", "Cranberry", "Apricot"
-};
+
 
 std::vector<std::pair<std::string, std::string>> g_recipes = {
     {"Apple Banana", "Apple Banana Smoothie: Blend 1 apple, 1 banana, 1 cup milk, and 1 tbsp honey until smooth."},
@@ -159,6 +154,9 @@ class VectorDatabase {
 
 };
 
+// This is a simple object for storing a document.
+// You can modify it to store custom information if needed.
+// Note: Do not forget to implement the write and read functions for your custom object.
 struct Document {
     std::string content;
     void write(std::ostream& out) const {
@@ -176,39 +174,50 @@ struct Document {
     }
 };
 
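// A minimal sketch of the custom document object the note above asks for,
// assuming the database only needs the write/read pair shown on Document.
// The RecipeDocument name and the length-prefixed string encoding are
// illustrative assumptions, not the project's API; needs <cstdint> for uint32_t.
struct RecipeDocument {
    std::string title;
    std::string content;

    void write(std::ostream& out) const {
        auto writeStr = [&out](const std::string& s) {
            uint32_t len = uint32_t(s.size());
            out.write(reinterpret_cast<const char*>(&len), sizeof(len));
            out.write(s.data(), len);
        };
        writeStr(title);
        writeStr(content);
    }

    void read(std::istream& in) {
        auto readStr = [&in](std::string& s) {
            uint32_t len = 0;
            in.read(reinterpret_cast<char*>(&len), sizeof(len));
            s.resize(len);
            in.read(&s[0], len);
        };
        readStr(title);
        readStr(content);
    }
};
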
+std::string retrieveKnowledge(const std::string& query, VectorDatabase<Document>& vdb) {
+    // Search for the most relevant recipes
+    // in the vector database
+    int k = 3; // Retrieve top-3 results
+    std::vector<float> searchEmbedding = g_embeddingInstance->getEmbeddingVector(g_embeddingInstance->model().vocab().tokenize(query, true, true));
+    auto results = vdb.searchKnn(searchEmbedding, k);
+
+    std::string knowledgeContent = "";
+
+    int cnt = 1;
+    std::cout << "============= database context =============\n";
+    while (!results.empty()) {
+        auto res = results.top();
+        std::cout << "\t" << cnt << "."
+            << " Distance: " << res.dist
+            << " ID: " << res.idx
+            << " content: " << res.content.content
+            << std::endl;
+        results.pop();
+        knowledgeContent += std::to_string(cnt++) + ": " + res.content.content + "\n";
+    }
+    std::cout << "=============================================\n";
+
+    return knowledgeContent;
+}
+
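// A minimal usage sketch for retrieveKnowledge, assuming the database has
// already been populated with embedded recipes. The numbered output format
// matches the loop above; the query string here is just an example.
//
//   std::string context = retrieveKnowledge("a smoothie with banana", vdb);
//   // context might look like:
//   // 1: Apple Banana Smoothie: Blend 1 apple, 1 banana, 1 cup milk, ...
//   // 2: <second-nearest recipe>
//   // 3: <third-nearest recipe>
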
 std::string generateResponse(ac::llama::Session& session, const std::string& prompt, VectorDatabase<Document>& vdb, int maxTokens = 512) {
     ac::llama::ChatFormat chatFormat("llama3");
     ac::llama::ChatMsg msg{.text = prompt, .role = "user"};
 
-    {
-        // Search for the most relevant recipe
-        // in the vector database
-        int k = 3; // Retrieve top-3 results
-        std::vector<float> seachEmbedding = g_embeddingInstance->getEmbeddingVector(g_embeddingInstance->model().vocab().tokenize(prompt, true, true));
-        auto result = vdb.searchKnn(seachEmbedding, k);
-
-        std::string knowledge = "You are a recipe assistant. Given the following relevant recipes, select the most relevant one or paraphrase it:\n";
-        std::cout << "Nearest neighbors:\n";
-        int cnt = 1;
-        while (!result.empty()) {
-            auto res = result.top();
-            std::cout << "\tDistance: " << res.dist
-                << " ID:" << res.idx
-                << " content: " << res.content.content << std::endl;
-            result.pop();
-            knowledge += std::to_string(cnt++) + ": " + res.content.content + "\n";
-        }
-
-        session.pushPrompt(g_chatInstance->model().vocab().tokenize(knowledge, false, false));
-    }
+    // 1. Fill the context with the relevant recipes
+    const std::string systemPrompt = "You are a recipe assistant. Given the following relevant recipes, select the most relevant one or paraphrase it:\n";
+    const std::string knowledge = retrieveKnowledge(prompt, vdb);
+    session.pushPrompt(g_chatInstance->model().vocab().tokenize(systemPrompt + knowledge, false, false));
 
+    // 2. Add the user prompt to the context
     auto formatted = chatFormat.formatMsg(msg, g_messages, true);
     g_messages.emplace_back(msg);
-    // auto formatted = chatFormat.formatChat(g_messages, true);
+    // Note: To format the full chat and push it into the context, uncomment the following line
+    // formatted = chatFormat.formatChat(g_messages, true);
     session.pushPrompt(g_chatInstance->model().vocab().tokenize(formatted, false, false));
 
+    // 3. Generate the response
     std::string response = "";
-
     for (int i = 0; i < maxTokens; ++i) {
         auto token = session.getToken();
         if (token == ac::llama::Token_Invalid) {
@@ -243,32 +252,20 @@ int main() try {
     // initialize the library
     ac::llama::initLibrary();
 
-    // This model won't work for this example, but it's a placeholder
-    // Download better model - llama3.2 8b for example
+    // Note: This model won't work for this example, but it's a placeholder.
+    // Download a better model - llama3.2 8b, for example.
     std::string modelGguf = AC_TEST_DATA_LLAMA_DIR "/gpt2-117m-q6_k.gguf";
     std::string embeddingModelGguf = AC_TEST_DATA_LLAMA_DIR "/bge-small-en-v1.5-f16.gguf";
-    auto modelLoadProgressCallback = [](float progress) {
-        const int barWidth = 50;
-        static float currProgress = 0;
-        auto delta = int(progress * barWidth) - int(currProgress * barWidth);
-        for (int i = 0; i < delta; i++) {
-            std::cout.put('=');
-        }
-        currProgress = progress;
-        if (progress == 1.f) {
-            std::cout << '\n';
-        }
-        return true;
-    };
 
     ac::llama::ResourceCache cache;
-    // create inference objects
-    auto model = cache.getModel({.gguf = modelGguf, .params = {}}, modelLoadProgressCallback);
+
+    // create objects for the inference
+    auto model = cache.getModel({.gguf = modelGguf});
     ac::llama::Instance instance(*model, {});
     g_chatInstance = &instance;
 
-    // create embedding objects
-    auto mEmbedding = cache.getModel({.gguf = embeddingModelGguf, .params = {}}, modelLoadProgressCallback);
+    // create objects for the embedding
+    auto mEmbedding = cache.getModel({.gguf = embeddingModelGguf});
     ac::llama::InstanceEmbedding iEmbedding(*mEmbedding, {});
     g_embeddingInstance = &iEmbedding;
 
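// A minimal sketch of how the embedding instance could populate the vector
// database from g_recipes. The vdb.insert call is a hypothetical placeholder
// (substitute VectorDatabase's real insertion API); the embedding call mirrors
// the one retrieveKnowledge uses for queries, which keeps distances comparable.
//
//   VectorDatabase<Document> vdb;
//   for (const auto& [name, recipe] : g_recipes) {
//       auto emb = iEmbedding.getEmbeddingVector(
//           iEmbedding.model().vocab().tokenize(recipe, true, true));
//       vdb.insert(emb, Document{recipe}); // hypothetical insert(embedding, doc)
//   }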