// Copyright (c) Alpaca Core
// SPDX-License-Identifier: MIT
//

// example of using alpaca-core's llama inference to compare the logits of two models

// llama
#include <ac/llama/Init.hpp>
#include <ac/llama/Model.hpp>
#include <ac/llama/Instance.hpp>
#include <ac/llama/Session.hpp>
#include <ac/llama/ControlVector.hpp>
#include <ac/llama/ResourceCache.hpp>
#include <ac/llama/LogitComparer.hpp>

// logging
#include <ac/jalog/Instance.hpp>
#include <ac/jalog/sinks/ColorSink.hpp>

// model source directory
#include "ac-test-data-llama-dir.h"

#include <cassert>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct GenerationStepData {
    std::string tokenStr;
    int32_t token;
    ac::llama::TokenDataVector data;
};

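// global resource manager and model cache shared by the Model wrappers below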
ac::local::ResourceManager g_rm;
ac::llama::ResourceCache g_rcache(g_rm);

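// thin wrapper owning a cached model and an inference instance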
class Model {
public:
    Model(const std::string& gguf, ac::llama::Model::Params params) {
        m_model = g_rcache.getModel({.gguf = gguf, .params = {params}});
        m_instance.reset(new ac::llama::Instance(*m_model, {
            .ctxSize = 2048,
        }));
    }

    struct GenerationResult {
        std::string initialPrompt;
        std::string result;
        std::vector<GenerationStepData> steps;
    };

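    // run a generation session: feed the prompt, sample up to maxTokens tokens,
    // and record the top-k token data at every step for later comparison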
    GenerationResult generate(std::string prompt, uint32_t maxTokens) {
        m_session = &m_instance->startSession({});

        auto promptTokens = m_model->vocab().tokenize(prompt, true, true);
        m_session->setInitialPrompt(promptTokens);

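        // capture the sampler's top-k token data right after the prompt is ingested;
        // it becomes step 0, attributed to the last prompt token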
        constexpr int32_t topK = 10;
        auto data = m_session->getSampledTokenData(topK);

        auto token = promptTokens.back();
        auto tokenStr = m_model->vocab().tokenToString(token);

        std::vector<GenerationStepData> genSteps;
        genSteps.push_back(GenerationStepData{
            .tokenStr = tokenStr,
            .token = token,
            .data = std::move(data)
        });

        std::string result;
        for (size_t i = 0; i < maxTokens; i++) {
            auto token = m_session->getToken();
            if (token == ac::llama::Token_Invalid) {
                // no more tokens
                break;
            }
            tokenStr = m_model->vocab().tokenToString(token);
            result += tokenStr;

            auto data = m_session->getSampledTokenData(topK);

            genSteps.push_back({
                .tokenStr = tokenStr,
                .token = token,
                .data = std::move(data)
            });
        }

        m_instance->stopSession();
        m_session = nullptr;

        return {
            .initialPrompt = prompt,
            .result = result,
            .steps = std::move(genSteps)
        };
    }

private:
    ac::llama::ResourceCache::ModelLock m_model;
    std::unique_ptr<ac::llama::Instance> m_instance;
    ac::llama::Session* m_session = nullptr;
};

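// generate with a large model and check, step by step, whether a smaller model
// produces matching top-k token data for the same text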
int main() try {
    ac::jalog::Instance jl;
    jl.setup().add<ac::jalog::sinks::ColorSink>();

    // initialize the library
    ac::llama::initLibrary();

    // load the models to compare
    std::string tmpFolder = AC_TEST_DATA_LLAMA_DIR "/../../../tmp/";
    std::string modelGguf = "Meta-Llama-3.1-70B-Instruct-Q5_K_S.gguf";
    std::string modelGguf2 = "Meta-Llama-3.1-8B-Instruct-Q5_K_S.gguf";

    Model m1(tmpFolder + modelGguf, {});
    Model m2(tmpFolder + modelGguf2, {});

    std::string prompt = "The first person to";
    std::cout << "Prompt: " << prompt << "\n";

    std::cout << "Models to compare:\n" << modelGguf << "\n" << modelGguf2 << "\n";
    std::cout << "Comparing...\n";

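    // generate up to 100 tokens with model 1, then replay each growing prefix
    // through model 2 and compare the two models' token data at that position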
    auto res = m1.generate(prompt, 100);
    std::cout << "Model 1 generated: " << res.result << "\n";

    std::string genPrompt = res.initialPrompt;
    for (size_t i = 0; i < res.steps.size(); i++) {
        auto& step = res.steps[i];
        if (i > 0) {
            genPrompt += step.tokenStr;
        }

        // replaying with 0 new tokens yields exactly one step:
        // the token data for the last position of genPrompt
        auto res2 = m2.generate(genPrompt, 0);
        assert(res2.steps.size() == 1);

        if (ac::llama::LogitComparer::compare(step.data, res2.steps[0].data)) {
            std::cout << "Models agree at this step. Generated text so far:\n" << genPrompt << "\n\n";
        }
    }
    std::cout << '\n';

    return 0;
}
catch (const std::exception& e) {
    std::cerr << "Error: " << e.what() << std::endl;
    return 1;
}