Commit b1a90ab

feat: add support for infill generation, ref #62
1 parent 1a4473c commit b1a90ab

File tree

6 files changed: +144, -14 lines

ac-local-plugin/code/LocalLlama.cpp

Lines changed: 16 additions & 2 deletions
@@ -207,12 +207,19 @@ struct LocalLlama {

     sc::StateGeneralInstance::OpRun::Return opRun(llama::Instance& instance, const sc::StateGeneralInstance::OpRun::Params& iparams) {
         auto& prompt = iparams.prompt.value();
+        auto& suffix = iparams.suffix.value();
         auto maxTokens = iparams.maxTokens.valueOr(0);

         auto& session = instance.startSession({});

         auto promptTokens = instance.model().vocab().tokenize(prompt, true, true);
-        session.setInitialPrompt(promptTokens);
+        if (suffix.empty()) {
+            session.setInitialPrompt(promptTokens);
+        } else {
+            auto suffixTokens = instance.model().vocab().tokenize(suffix, true, true);
+            session.setInitialPrompt({});
+            session.pushPrompt(promptTokens, suffixTokens);
+        }

         ac::llama::AntipromptManager antiprompt;
         for (auto& ap : iparams.antiprompts.value()) {
@@ -247,12 +254,19 @@ struct LocalLlama {
         const sc::StateGeneralInstance::OpStream::Params& iparams) {

         auto& prompt = iparams.prompt.value();
+        auto& suffix = iparams.suffix.value();
         auto maxTokens = iparams.maxTokens.valueOr(0);

         auto& session = instance.startSession({});

         auto promptTokens = instance.model().vocab().tokenize(prompt, true, true);
-        session.setInitialPrompt(promptTokens);
+        if (suffix.empty()) {
+            session.setInitialPrompt(promptTokens);
+        } else {
+            auto suffixTokens = instance.model().vocab().tokenize(suffix, true, true);
+            session.setInitialPrompt({});
+            session.pushPrompt(promptTokens, suffixTokens);
+        }

         ac::llama::AntipromptManager antiprompt;
         for (auto& ap : iparams.antiprompts.value()) {
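Note: the handler change is identical in opRun and opStream. When the request carries a non-empty suffix, the session starts with an empty initial prompt and the prefix/suffix pair goes through the new two-argument pushPrompt, which frames them with the model's FIM tokens (see Session.cpp below). A minimal sketch of that branching as a standalone helper; the helper name startWithOptionalSuffix is hypothetical, while the ac::llama types and calls are the ones used in the diff above:

```cpp
// Sketch only: mirrors the branching added to opRun/opStream above.
// ac::llama::Instance and Session are the types used by the plugin;
// this free function itself is hypothetical and not part of the commit.
#include <ac/llama/Instance.hpp>
#include <ac/llama/Session.hpp>
#include <string>

void startWithOptionalSuffix(ac::llama::Instance& instance,
                             ac::llama::Session& session,
                             const std::string& prompt,
                             const std::string& suffix) {
    auto promptTokens = instance.model().vocab().tokenize(prompt, true, true);
    if (suffix.empty()) {
        // plain completion: the prompt becomes the initial context
        session.setInitialPrompt(promptTokens);
    } else {
        // infill: start empty, then push prefix + suffix so the session
        // can wrap them in the model's FIM special tokens
        auto suffixTokens = instance.model().vocab().tokenize(suffix, true, true);
        session.setInitialPrompt({});
        session.pushPrompt(promptTokens, suffixTokens);
    }
}
```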

ac-local-plugin/schema/ac/schema/LlamaCpp.hpp

Lines changed: 2 additions & 0 deletions
@@ -102,12 +102,14 @@ struct StateGeneralInstance {

    struct InferenceParams {
        Field<std::string> prompt;
+        Field<std::string> suffix = Default();
        Field<std::vector<std::string>> antiprompts = Default();
        Field<uint32_t> maxTokens = Default(0);

        template <typename Visitor>
        void visitFields(Visitor& v) {
            v(prompt, "prompt", "Prompt to complete");
+            v(suffix, "suffix", "Suffix of the prompt. Used for infill (code generation, for example)");
            v(antiprompts, "antiprompts", "Antiprompts to trigger stop");
            v(maxTokens, "max_tokens", "Maximum number of tokens to generate. 0 for unlimited");
        }
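The new field rides alongside prompt and is optional (Default()), so existing requests without a suffix are unaffected. A self-contained sketch of the visitFields pattern the schema relies on, with plain std types standing in for the project's Field<> wrapper; the names and descriptions mirror the struct above, and the printing visitor is purely illustrative:

```cpp
// Standalone illustration of the visitFields pattern used by InferenceParams.
// Plain std types replace the project's Field<> wrapper; wire names mirror
// the schema above ("prompt", "suffix", "antiprompts", "max_tokens").
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct InferenceParamsSketch {
    std::string prompt;
    std::string suffix;                   // empty => plain completion, non-empty => infill
    std::vector<std::string> antiprompts;
    uint32_t maxTokens = 0;

    template <typename Visitor>
    void visitFields(Visitor& v) {
        v(prompt, "prompt", "Prompt to complete");
        v(suffix, "suffix", "Suffix of the prompt. Used for infill (code generation, for example)");
        v(antiprompts, "antiprompts", "Antiprompts to trigger stop");
        v(maxTokens, "max_tokens", "Maximum number of tokens to generate. 0 for unlimited");
    }
};

int main() {
    InferenceParamsSketch p{.prompt = "def helloworld():", .suffix = "\n    return 0\n"};
    // a trivial visitor that just lists wire names and descriptions
    auto printer = [](auto& /*field*/, const char* name, const char* desc) {
        std::cout << name << ": " << desc << '\n';
    };
    p.visitFields(printer);
}
```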

code/ac/llama/Session.cpp

Lines changed: 48 additions & 11 deletions
@@ -86,28 +86,65 @@ void Session::setInitialPrompt(std::span<const Token> initialPrompt) {
     m_state.m_phase = State::Phase::Generating;
 }

-void Session::pushPrompt(std::span<const Token> prompt) {
+void Session::pushPrompt(std::span<const Token> prompt, std::span<const Token> postfix) {
     if (m_state.m_phase != State::Phase::Generating) {
         throw_ex{} << "Session hasn't started yet";
     }

     flushPendingState();

-    if (!prompt.empty()) {
-        auto& sampler = m_instance.sampler();
-        auto& model = m_instance.model();
+    if (prompt.empty() && postfix.empty()) {
+        throw_ex{} << "Prompt and postfix are empty";
+    }
+
+    auto& model = m_instance.model();
+    auto& sampler = m_instance.sampler();
+
+    // reset sampling and don't allow previous inputs to affect the generation
+    sampler.reset();
+
+    std::vector<Token> tokens;
+    constexpr uint32_t maxAdditionalTokens = 4; // bos + fim_pre + fim_suf + fim_mid
+    tokens.reserve(prompt.size() + postfix.size() + maxAdditionalTokens);

-        // reset sampling and don't allow previous inputs to affect the generation
-        sampler.reset();
+    if (model.prefixInputsWithBos()) {
+        const auto tokenBos = llama_vocab_bos(model.vocab().lvocab());
+        tokens.push_back(tokenBos);
+    }

-        if (model.prefixInputsWithBos()) {
-            const auto tokenBos = llama_vocab_bos(model.vocab().lvocab());
-            // add bos token to the prompt
-            doDecode({&tokenBos, 1}, Source::InteractivePrompt);
+    auto safeAddToken = [&](Token token, const std::string& tokenName) {
+        if (token >= 0) {
+            tokens.push_back(token);
+        } else {
+            LLAMA_LOG(Warning, "Model doesn't have a ", tokenName, " token");
        }
+    };
+
+    if (!postfix.empty()) {
+        auto tokenFIMPre = llama_vocab_fim_pre(model.vocab().lvocab());
+        safeAddToken(tokenFIMPre, "FIM Prefix");
+    }
+
+    if (!prompt.empty()) {
+        tokens.insert(tokens.end(), prompt.begin(), prompt.end());
+    }
+
+    if (!postfix.empty()) {
+        auto tokenFIMSuff = llama_vocab_fim_suf(model.vocab().lvocab());
+        safeAddToken(tokenFIMSuff, "FIM Suffix");

-        doDecode(prompt, Source::InteractivePrompt);
+        tokens.insert(tokens.end(), postfix.begin(), postfix.end());
+
+        auto tokenFIMMid = llama_vocab_fim_mid(model.vocab().lvocab());
+        safeAddToken(tokenFIMMid, "FIM Middle");
     }
+
+    if (tokens.size() > m_state.maxTokens) {
+        const auto ctxLen = llama_n_ctx(m_ctx);
+        throw_ex{} << "Prompt too long. Got " << tokens.size() << " tokens, max: " << ctxLen - 4;
+    }
+
+    doDecode(tokens, Source::InteractivePrompt);
 }

 Token Session::getToken() {
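In short, when a non-empty postfix is passed, pushPrompt now lays out the context as: optional BOS, FIM-prefix marker, the prefix tokens, FIM-suffix marker, the suffix tokens, and finally the FIM-middle marker; generation then produces the fill for the gap. A minimal sketch of the same layout written directly against the llama.cpp vocab API; buildInfillTokens and its parameters are hypothetical, and the llama_vocab_fim_* getters return a negative id when the model has no FIM tokens, which is why the code above only warns and skips them:

```cpp
// Sketch: the token order produced by Session::pushPrompt(prompt, postfix)
// when a non-empty suffix is given, expressed with the llama.cpp C API.
// `vocab` is assumed to be a loaded const llama_vocab*, and `prefix`/`suffix`
// are already-tokenized spans.
#include <span>
#include <vector>
#include <llama.h>

std::vector<llama_token> buildInfillTokens(const llama_vocab* vocab,
                                           std::span<const llama_token> prefix,
                                           std::span<const llama_token> suffix,
                                           bool addBos) {
    std::vector<llama_token> tokens;
    tokens.reserve(prefix.size() + suffix.size() + 4); // bos + fim_pre + fim_suf + fim_mid

    auto pushIfPresent = [&](llama_token t) {
        if (t >= 0) tokens.push_back(t); // negative id: model has no such special token
    };

    if (addBos) pushIfPresent(llama_vocab_bos(vocab));
    pushIfPresent(llama_vocab_fim_pre(vocab));                  // FIM prefix marker
    tokens.insert(tokens.end(), prefix.begin(), prefix.end());  // code before the hole
    pushIfPresent(llama_vocab_fim_suf(vocab));                  // FIM suffix marker
    tokens.insert(tokens.end(), suffix.begin(), suffix.end());  // code after the hole
    pushIfPresent(llama_vocab_fim_mid(vocab));                  // FIM middle marker
    return tokens;                                              // decode, then sample the fill
}
```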

code/ac/llama/Session.hpp

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ class Session {
    bool setState(std::span<uint8_t> state);

    // main functions to interact with the model
-    void pushPrompt(std::span<const Token> prompt);
+    void pushPrompt(std::span<const Token> prompt, std::span<const Token> postfix = {});
    Token getToken();
    TokenDataVector getSampledTokenData(int32_t topK, float topP = 0.95f);
    std::vector<uint8_t> getState();
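Because postfix defaults to an empty span, existing single-argument callers keep compiling and behaving as before; only callers that pass a second span opt into the FIM path. A caller-side sketch of the new signature; session and model are set up exactly as in example/e-infill.cpp below, and the code strings are placeholders:

```cpp
// Usage sketch of the extended pushPrompt signature (hypothetical helper;
// the session/model setup is omitted and matches example/e-infill.cpp).
#include <ac/llama/Model.hpp>
#include <ac/llama/Session.hpp>
#include <string>

void completeGap(ac::llama::Session& session, ac::llama::Model& model) {
    std::string prefix = "int add(int a, int b) {\n    return ";
    std::string suffix = ";\n}\n";

    session.setInitialPrompt({});                       // start with an empty context
    session.pushPrompt(
        model.vocab().tokenize(prefix, true, true),     // text before the hole
        model.vocab().tokenize(suffix, true, true));    // text after the hole
    // subsequent session.getToken() calls yield the infill for the gap
}
```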

example/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ endfunction()

 add_example(basic)
 add_example(embedding)
+add_example(infill)

 CPMAddPackage(gh:alpaca-core/[email protected])
 if(TARGET ac-dev::imgui-sdl-app)

example/e-infill.cpp

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+// Copyright (c) Alpaca Core
+// SPDX-License-Identifier: MIT
+//
+
+// Code completion example of using alpaca-core's llama inference
+
+// llama
+#include <ac/llama/Init.hpp>
+#include <ac/llama/Model.hpp>
+#include <ac/llama/Instance.hpp>
+#include <ac/llama/Session.hpp>
+#include <ac/llama/ResourceCache.hpp>
+
+// logging
+#include <ac/jalog/Instance.hpp>
+#include <ac/jalog/sinks/ColorSink.hpp>
+
+// model source directory
+#include "ac-test-data-llama-dir.h"
+
+#include <iostream>
+#include <string>
+
+int main() try {
+    ac::jalog::Instance jl;
+    jl.setup().add<ac::jalog::sinks::ColorSink>();
+
+    // initialize the library
+    ac::llama::initLibrary();
+
+    // load model
+    // download a better model for good code completion results, such as
+    // https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct-GGUF/tree/main
+    // std::string modelGguf = AC_TEST_DATA_LLAMA_DIR "/../../../models/qwen2.5-coder-3b-instruct-q8_0.gguf";
+    std::string modelGguf = AC_TEST_DATA_LLAMA_DIR "/gpt2-117m-q6_k.gguf";
+
+    ac::local::ResourceManager rm;
+    ac::llama::ResourceCache cache(rm);
+    auto model = cache.getModel({.gguf = modelGguf, .params = {}});
+
+    // create inference instance
+    ac::llama::Instance instance(*model, {});
+
+    // start session
+    auto& session = instance.startSession({});
+    session.setInitialPrompt({});
+
+    std::string input_prefix = "def helloworld():\n    print(\"hell";
+    std::string input_suffix = "\n    print(\"goodbye world\")\n";
+    std::cout << "<prefix>\n" << input_prefix << "\n</prefix> +\n <place_to_fill> + \n" << "<postfix>\n" << input_suffix << "\n</postfix>\n";
+
+    session.pushPrompt(
+        model->vocab().tokenize(input_prefix, true, true),
+        model->vocab().tokenize(input_suffix, true, true));
+
+    std::cout << "Final result: \n" << input_prefix;
+
+    // generate and print 100 tokens
+    for (int i = 0; i < 100; ++i) {
+        auto token = session.getToken();
+        if (token == ac::llama::Token_Invalid) {
+            // no more tokens
+            break;
+        }
+
+        auto str = model->vocab().tokenToString(token);
+        std::cout << str;
+    }
+    std::cout << input_suffix << "\n";
+
+    return 0;
+}
+catch (const std::exception& e) {
+    std::cerr << "Error: " << e.what() << std::endl;
+    return 1;
+}

0 commit comments