
Commit c1a2ad0

wwoodsTM and l3utterfly committed
DRY: Added co-author, testing, fix for consistent prompt sampling
Co-authored-by: l3utterfly <[email protected]>
1 parent 52530f2 commit c1a2ad0

14 files changed: +182 -93 lines changed

common/sampling.cpp

Lines changed: 15 additions & 15 deletions
@@ -145,7 +145,7 @@ std::string gpt_sampler_params::print() const {
     return std::string(result);
 }
 
-struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
+struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params, int32_t context_size) {
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
 
     lparams.no_perf = params.no_perf;
@@ -180,7 +180,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
                 params.ignore_eos));
 
     if (params.dry_multiplier != 0.0f && params.dry_base != 0.0f) {
-        auto * dry_sampler = llama_sampler_init_dry(model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n);
+        auto * dry_sampler = llama_sampler_init_dry(model, context_size, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n);
 
         llama_sampler_dry_set_seq_breakers(dry_sampler, params.dry_sequence_breakers);
         llama_sampler_chain_add(result->chain, dry_sampler);
@@ -289,19 +289,19 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
     // Check and set the context size if it hasn't been set yet
-    if (!gsmpl->context_size_set) {
-        gsmpl->n_ctx = llama_n_ctx(ctx);
-        gsmpl->context_size_set = true;
-
-        // Update the DRY sampler's context size if it is active
-        for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
-            auto * sampler = llama_sampler_chain_get(gsmpl->chain, i);
-            if (strcmp(llama_sampler_name(sampler), "dry") == 0) {
-                llama_sampler_dry_set_context_size(sampler, gsmpl->n_ctx);
-                break;
-            }
-        }
-    }
+    // if (!gsmpl->context_size_set) {
+    //     gsmpl->n_ctx = llama_n_ctx(ctx);
+    //     gsmpl->context_size_set = true;
+
+    //     // Update the DRY sampler's context size if it is active
+    //     for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
+    //         auto * sampler = llama_sampler_chain_get(gsmpl->chain, i);
+    //         if (strcmp(llama_sampler_name(sampler), "dry") == 0) {
+    //             llama_sampler_dry_set_context_size(sampler, gsmpl->n_ctx);
+    //             break;
+    //         }
+    //     }
+    // }
 
     gsmpl->set_logits(ctx, idx);
 
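For context, here is a rough caller-side sketch (not part of this commit) of how the updated API is meant to be used: the context size is now passed to gpt_sampler_init() up front, so the DRY sampler no longer has to discover it lazily on the first call to gpt_sampler_sample(). The function below is hypothetical and assumes the llama.cpp common headers plus an already-loaded model and context.

// illustrative only -- not part of this commit
#include "common.h"
#include "sampling.h"

#include <vector>

static llama_token generate_one(struct llama_model * model, struct llama_context * ctx,
                                const std::vector<llama_token> & prompt_tokens,
                                const gpt_sampler_params & sparams) {
    // pass the context size explicitly so the DRY penalty window is sized at init time
    struct gpt_sampler * smpl = gpt_sampler_init(model, sparams, llama_n_ctx(ctx));

    // feed the full prompt into the sampler so repetition/DRY penalties see a consistent history
    for (const llama_token tok : prompt_tokens) {
        gpt_sampler_accept(smpl, tok, /* accept_grammar= */ false);
    }

    // ... after llama_decode() has produced logits for the last prompt token ...
    const llama_token id = gpt_sampler_sample(smpl, ctx, /* idx= */ -1, /* grammar_first= */ false);
    gpt_sampler_accept(smpl, id, /* accept_grammar= */ true);

    gpt_sampler_free(smpl);
    return id;
}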

common/sampling.h

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ struct gpt_sampler;
 
 // llama_sampler API overloads
 
-struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
+struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params, int32_t context_size);
 
 void gpt_sampler_free(struct gpt_sampler * gsmpl);

examples/infill/infill.cpp

Lines changed: 1 addition & 1 deletion
@@ -298,7 +298,7 @@ int main(int argc, char ** argv) {
             LOG_INF("Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
-    smpl = gpt_sampler_init(model, sparams);
+    smpl = gpt_sampler_init(model, sparams, n_ctx);
 
     LOG_INF("sampler seed: %u\n", gpt_sampler_get_seed(smpl));
     LOG_INF("sampler params: \n%s\n", sparams.print().c_str());

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
 
     LOG("\n");
 
-    struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
+    struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams, llama_n_ctx(ctx_llava->ctx_llama));
     if (!smpl) {
         LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);

examples/llava/minicpmv-cli.cpp

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_par
 
     LOG_INF("\n");
 
-    struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
+    struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams, llama_n_ctx(ctx_llava->ctx_llama));
     return smpl;
 }

examples/lookahead/lookahead.cpp

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
 
     // target model sampling context
-    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
+    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams, llama_n_ctx(ctx));
 
     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);

examples/lookup/lookup.cpp

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ int main(int argc, char ** argv){
 
     bool has_eos = false;
 
-    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
+    struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams, max_context_size);
 
     std::vector<llama_token> draft;

examples/main/main.cpp

Lines changed: 1 addition & 1 deletion
@@ -448,7 +448,7 @@ int main(int argc, char ** argv) {
         }
     }
 
-    smpl = gpt_sampler_init(model, sparams);
+    smpl = gpt_sampler_init(model, sparams, n_ctx);
     if (!smpl) {
         LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
         return 1;

examples/parallel/parallel.cpp

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.smpl = gpt_sampler_init(model, params.sparams);
+        client.smpl = gpt_sampler_init(model, params.sparams, n_ctx);
     }
 
     std::vector<llama_token> tokens_system;

examples/server/server.cpp

Lines changed: 33 additions & 32 deletions
@@ -163,7 +163,7 @@ struct server_slot {
     int32_t n_prompt_tokens           = 0;
     int32_t n_prompt_tokens_processed = 0;
 
-    json prompt; // can be either a string, array of strings or array of token ids
+    json prompt; // can be either a string, array of strings, array of token ids, or mixed array of strings and token ids
 
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
@@ -975,16 +975,15 @@ struct server_context {
         }
 
         if ((prompt->is_string()) ||
-            (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
-            (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
+            (prompt->is_array() && !prompt->empty() && (prompt->at(0).is_string() || prompt->at(0).is_number_integer()))) {
             slot.prompt = *prompt;
         } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
             slot.prompt = prompt->at(0);
         } else if (prompt->is_array() && prompt->size() > 1) {
             // array of strings
             for (const auto & el : *prompt) {
                 if (!el.is_string()) {
-                    send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                    send_error(task, "\"prompt\" must be a string, an array of strings, an array of integers, or a mixed array of strings and integers", ERROR_TYPE_INVALID_REQUEST);
                     return false;
                 }
             }
@@ -1062,18 +1061,10 @@ struct server_context {
         }
 
         {
-            // These lines seem to force the clearing of sampler data between generations:
-
-            // if (slot.smpl != nullptr) {
-            //     gpt_sampler_free(slot.smpl);
-            // }
-            // slot.smpl = gpt_sampler_init(model, slot.sparams);
-
-            // Changed it to this so data could be maintained between generations:
-
-            if (slot.smpl == nullptr) {
-                slot.smpl = gpt_sampler_init(model, slot.sparams);
+            if (slot.smpl != nullptr) {
+                gpt_sampler_free(slot.smpl);
             }
+            slot.smpl = gpt_sampler_init(model, slot.sparams, slot.n_ctx);
 
             if (slot.smpl == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
@@ -1518,24 +1509,25 @@ struct server_context {
             throw std::runtime_error(error_msg);
         }
         json prompt = data.at("prompt");
-        // if the prompt is a singleton (i.e. a string, a list of tokens, or a mixed array of strings and tokens), we only need to create a single task
-        if (prompt.is_string() || (prompt.is_array() && !prompt.empty() && !prompt[0].is_array())) {
-            bool is_mixed = false;
-            bool has_string = prompt.is_string();
+
+        auto is_valid_singleton_array = [](const json& arr) {
             bool has_number = false;
-            if (prompt.is_array()) {
-                for (const auto& elem : prompt) {
-                    if (elem.is_string()) has_string = true;
-                    else if (elem.is_number()) has_number = true;
-                    if (has_string && has_number) {
-                        is_mixed = true;
-                        break;
-                    }
+            for (const auto& elem : arr) {
+                if (elem.is_number()) {
+                    has_number = true;
+                } else if (!elem.is_string()) {
+                    return false;
                 }
             }
+            return has_number;
+        };
+
+        bool is_singleton = prompt.is_string() || (prompt.is_array() && is_valid_singleton_array(prompt));
+
+        // if the prompt is a singleton (i.e. a string, a list of tokens, or a mixed array of strings and tokens), we only need to create a single task
+        if (prompt.is_string() || (prompt.is_array() && is_valid_singleton_array(prompt))) {
             data["index"] = 0;
             create_task(data, false, nullptr);
-            SRV_DBG("creating single%s prompt task\n", is_mixed ? " mixed" : "");
         }
         // otherwise, it's a multiple-prompt task or a rerank task, we break it into smaller tasks
         else if (prompt.is_array()) {
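
As a rough illustration (not part of the commit), the new check accepts a prompt as a single task when it is a plain string, or a flat array whose elements are only strings and numbers with at least one number; an array containing only strings is still split into one task per string. A standalone sketch of that classification, assuming the nlohmann JSON library used by the server:

// illustrative only -- not part of this commit
#include <nlohmann/json.hpp>
#include <cstdio>

using json = nlohmann::json;

// mirrors the lambda added above: flat array of strings/numbers containing at least one number
static bool is_valid_singleton_array(const json & arr) {
    bool has_number = false;
    for (const auto & elem : arr) {
        if (elem.is_number()) {
            has_number = true;
        } else if (!elem.is_string()) {
            return false;
        }
    }
    return has_number;
}

int main() {
    const json prompts = json::array({
        json("just a string"),                  // single task: plain string
        json::array({1, 2, 3}),                 // single task: token ids
        json::array({"prefix", 42, "suffix"}),  // single task: mixed strings and token ids
        json::array({"a", "b"}),                // split: array of strings -> one task per string
    });
    for (const auto & p : prompts) {
        const bool singleton = p.is_string() || (p.is_array() && is_valid_singleton_array(p));
        printf("%s -> %s\n", p.dump().c_str(), singleton ? "single task" : "split into tasks");
    }
    return 0;
}
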
@@ -2154,7 +2146,8 @@ struct server_context {
                     GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                 }
 
-                //gpt_sampler_reset(slot.smpl); // This line is likely preventing sampler state from being maintained from generation to generation
+                // Should this be (re-)moved?
+                gpt_sampler_reset(slot.smpl);
 
                 if (!slot.params.cache_prompt) {
                     slot.n_past_se = 0;
@@ -2165,10 +2158,13 @@ struct server_context {
                     // reuse any previously computed tokens that are common with the new prompt
                     slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
+                    // Not sure if the for loop below should happen in multiple places but for now I moved it
+                    // until after the entire prompt is processed so that sampling would happen consistently.
+
                     // push the prompt into the sampling context (do not apply grammar)
-                    for (int i = 0; i < slot.n_past; ++i) {
-                        gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
-                    }
+                    // for (int i = 0; i < slot.n_past; ++i) {
+                    //     gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
+                    // }
                 }
             }
 
@@ -2264,6 +2260,11 @@ struct server_context {
 
             GGML_ASSERT(batch.n_tokens > 0);
 
+            // Process all prompt tokens through sampler system
+            for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+                gpt_sampler_accept(slot.smpl, prompt_tokens[i], false);
+            }
+
             // extract the logits only for the last token
             batch.logits[batch.n_tokens - 1] = true;
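
Taken together with the hunk above that comments out the per-cached-token accept loop, the effect is that the slot's sampler is now primed with every prompt token once the batch has been built, instead of replaying only the slot.n_past tokens shared with the prompt cache, so DRY and repetition penalties see the same history whether or not the prompt was cached. A hypothetical helper capturing that pattern (illustrative only, assuming common/sampling.h):

// illustrative only -- not part of this commit
#include "sampling.h"

#include <vector>

// After gpt_sampler_reset(), push the full prompt through the sampler so its
// penalty history does not depend on how much of the prompt happened to be cached.
static void prime_sampler_with_prompt(struct gpt_sampler * smpl,
                                      const std::vector<llama_token> & prompt_tokens) {
    for (const llama_token tok : prompt_tokens) {
        gpt_sampler_accept(smpl, tok, /* accept_grammar= */ false);
    }
}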
