
Commit 9e5e09d

sampling : remove backend-dist option (wip)
This commit removes the `--backend-dist` option and instead uses the configured
`--samplers` chain to determine which samplers run on the backend.

Backend sampling is still enabled with `--backend_sampling`, and the sampler
chain, either explicitly specified using `--samplers` or the default, is
automatically analyzed to determine which samplers can run on the backend. The
system finds the longest contiguous chain of backend-supported samplers from
the start of the sampler sequence. For example:

* If the chain is `top-k -> temperature -> top-p`, and both `top-k` and
  `temperature` are backend-supported but `top-p` is not, then `top-k` and
  `temperature` will run on the backend, while `top-p` and subsequent samplers
  run on the CPU.
* If all configured samplers are supported, the final distribution sampling
  will also happen on the backend, transferring only the sampled token IDs back
  to the host.
* If the sampler chain starts with an unsupported sampler (e.g., `penalties`),
  all sampling runs on the CPU. Note that this is currently the case with the
  default sampler chain, so to use backend sampling a sampler chain must be
  specified explicitly. See below for an example.

The following shows how llama-cli can be run with backend sampling:
```console
$ llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
    --prompt 'What is the capital of Sweden?' \
    -n 20 \
    -no-cnv \
    --verbose-prompt \
    -ngl 40 \
    --backend-sampling \
    --samplers 'top_k;temperature'
```
In this case all sampling will happen on the backend since both `top_k` and
`temperature` are supported backend samplers.

To enable partial backend sampling (hybrid sampling), for example running
`top_k` and `temperature` on the backend and `top_p` on the CPU, the following
sampler chain could be specified:
```console
$ llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
    --prompt 'What is the capital of Sweden?' \
    -n 20 \
    -no-cnv \
    --verbose-prompt \
    -ngl 40 \
    --backend-sampling \
    --samplers 'top_k;temperature;top_p'
```

If this looks good then I'll follow up with updates to the llama-cli and
llama-server documentation to reflect these changes.
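For reference, the prefix-splitting rule described above can be sketched as a tiny standalone program. This is an illustration, not part of the patch: `backend_supported` and `backend_prefix_length` are hypothetical stand-ins for the `sampler_backend_supported()` helper and the counting loop that the patch adds to `common/sampling.cpp`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Illustrative stand-in for sampler_backend_supported(): at the time of
// writing only top_k and temperature have backend implementations.
static bool backend_supported(const std::string & name) {
    return name == "top_k" || name == "temperature";
}

// Longest contiguous run of backend-supported samplers from the start of the
// chain. Samplers [0, count) run on the backend; the rest run on the CPU.
static size_t backend_prefix_length(const std::vector<std::string> & samplers) {
    size_t count = 0;
    for (const auto & name : samplers) {
        if (!backend_supported(name)) {
            break; // the first unsupported sampler ends the backend prefix
        }
        count++;
    }
    return count;
}

int main() {
    assert(backend_prefix_length({"top_k", "temperature", "top_p"}) == 2); // hybrid: top_p stays on the CPU
    assert(backend_prefix_length({"penalties", "top_k"})            == 0); // default-style chain: all sampling on the CPU
    assert(backend_prefix_length({"top_k", "temperature"})          == 2); // full chain: dist can run on the backend too
    return 0;
}
```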
1 parent 53dca56 commit 9e5e09d

15 files changed: +214 additions, -96 deletions

common/arg.cpp

Lines changed: 0 additions & 8 deletions

```diff
@@ -1520,14 +1520,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.backend_sampling = true;
         }
     ).set_sparam());
-    add_opt(common_arg(
-        {"--backend-dist"},
-        "perform final (distribution) sampling on backend (default: disabled)",
-        [](common_params & params) {
-            params.sampling.backend_dist = true;
-            params.sampling.backend_sampling = true;
-        }
-    ).set_sparam());
     add_opt(common_arg(
         {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",
```

common/common.cpp

Lines changed: 10 additions & 5 deletions

```diff
@@ -1021,12 +1021,17 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     // backend sampling initialization
     if (params.sampling.backend_sampling) {
-        iparams.samplers_seq_config.resize(cparams.n_seq_max);
-        for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
-            iparams.samplers_seq_config[i] = { i, common_sampler_backend_init(model, params.sampling) };
+        llama_sampler * backend_chain = common_sampler_backend_init(model, params.sampling);
+        if (backend_chain != nullptr) {
+            iparams.samplers_seq_config.resize(cparams.n_seq_max);
+            for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
+                iparams.samplers_seq_config[i] = { i, llama_sampler_clone(backend_chain) };
+            }
+            cparams.samplers = iparams.samplers_seq_config.data();
+            cparams.n_samplers = cparams.n_seq_max;
+
+            llama_sampler_free(backend_chain);
         }
-        cparams.samplers = iparams.samplers_seq_config.data();
-        cparams.n_samplers = cparams.n_seq_max;
     }
 
     llama_context * lctx = llama_init_from_model(model, cparams);
```

common/common.h

Lines changed: 0 additions & 2 deletions

```diff
@@ -213,9 +213,7 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
-    // Backend sampling flags
     bool backend_sampling = false; // enable backend sampling
-    bool backend_dist     = false; // backend performs final sampling (dist)
 
     // print the parameters into a string
     std::string print() const;
```

common/sampling.cpp

Lines changed: 163 additions & 33 deletions

```diff
@@ -105,7 +105,8 @@ struct common_sampler {
     common_params_sampling params;
 
     struct llama_sampler * grmr;
-    struct llama_sampler * chain;
+    struct llama_sampler * chain;         // CPU sampling chain
+    struct llama_sampler * backend_chain; // Backend sampling chain
 
     ring_buffer<llama_token> prev;
 
@@ -118,6 +119,9 @@ struct common_sampler {
 
         llama_sampler_reset(grmr);
         llama_sampler_reset(chain);
+        if (backend_chain) {
+            llama_sampler_reset(backend_chain);
+        }
     }
 
     void set_logits(struct llama_context * ctx, int idx) {
@@ -165,6 +169,20 @@ static bool sampler_enabled(const struct common_params_sampling & params, enum c
     return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end();
 }
 
+static bool sampler_backend_supported(enum common_sampler_type type) {
+    switch (type) {
+        case COMMON_SAMPLER_TYPE_TOP_K:
+        case COMMON_SAMPLER_TYPE_TEMPERATURE:
+            return true;
+        default:
+            return false;
+    }
+}
+
+static bool has_logit_bias(const struct common_params_sampling & params) {
+    return !params.logit_bias.empty();
+}
+
 std::string common_params_sampling::print() const {
     char result[1024];
 
@@ -249,22 +267,86 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     }
 
     auto * result = new common_sampler {
-        /* .params = */ params,
-        /* .grmr   = */ grmr,
-        /* .chain  = */ llama_sampler_chain_init(lparams),
-        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-        /* .cur    = */ {},
-        /* .cur_p  = */ {},
+        /* .params        = */ params,
+        /* .grmr          = */ grmr,
+        /* .chain         = */ llama_sampler_chain_init(lparams),
+        /* .backend_chain = */ nullptr,
+        /* .prev          = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur           = */ {},
+        /* .cur_p         = */ {},
     };
 
-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_logit_bias(
-                llama_vocab_n_tokens(vocab),
-                params.logit_bias.size(),
-                params.logit_bias.data()));
+    size_t backend_sampler_count = 0;
+    if (params.backend_sampling && params.mirostat == 0) {
+        if (has_logit_bias(params)) {
+            backend_sampler_count++;
+        }
+
+        // Find the longest contiguous chain of backend-supported samplers from the start
+        for (const auto & sampler_type : params.samplers) {
+            if (sampler_backend_supported(sampler_type)) {
+                backend_sampler_count++;
+            } else {
+                break;
+            }
+        }
+    }
+
+    // If the samplers combination is supported then we can build the backend chain.
+    if (backend_sampler_count > 0 || (params.backend_sampling && has_logit_bias(params))) {
+        llama_sampler_chain_params backend_params = llama_sampler_chain_default_params();
+        backend_params.no_perf = params.no_perf;
+        result->backend_chain = llama_sampler_chain_init(backend_params);
+
+        if (has_logit_bias(params)) {
+            llama_sampler_chain_add(result->backend_chain,
+                    llama_sampler_backend_init_logit_bias(
+                        llama_vocab_n_tokens(vocab),
+                        params.logit_bias.size(),
+                        params.logit_bias.data()));
+        }
+
+        size_t backend_idx = 0;
+        for (const auto & sampler_type : params.samplers) {
+            if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
+                break;
+            }
+
+            switch (sampler_type) {
+                case COMMON_SAMPLER_TYPE_TOP_K:
+                    if (params.top_k > 0) {
+                        llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_top_k(params.top_k));
+                    }
+                    backend_idx++;
+                    break;
+                case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                    if (params.temp > 0.0f) {
+                        llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_temp(params.temp));
+                    }
+                    backend_idx++;
+                    break;
+                default:
+                    GGML_ASSERT(false && "unsupported backend sampler");
+            }
+        }
+    }
+
+    size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
+    bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
+
+    // Build CPU chain
+    if (!params.backend_sampling || !has_logit_bias(params)) {
+        llama_sampler_chain_add(result->chain,
+                llama_sampler_init_logit_bias(
+                    llama_vocab_n_tokens(vocab),
+                    params.logit_bias.size(),
+                    params.logit_bias.data()));
+    }
 
     if (params.mirostat == 0) {
-        for (const auto & cnstr : params.samplers) {
+        // Add remaining CPU samplers
+        for (size_t i = cpu_start_idx; i < params.samplers.size(); i++) {
+            const auto & cnstr = params.samplers[i];
             switch (cnstr) {
                 case COMMON_SAMPLER_TYPE_DRY:
                     {
@@ -308,7 +390,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                     GGML_ASSERT(false && "unknown sampler type");
             }
         }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+
+        // If all samplers are on backend, add dist to backend; otherwise add to CPU
+        if (result->backend_chain && !cpu_has_samplers) {
+            llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_dist(params.seed));
+        } else {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+        }
     } else if (params.mirostat == 1) {
         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
         llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
@@ -323,36 +411,74 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 }
 
 struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
-    if (!params.backend_sampling) {
+    if (!params.backend_sampling || params.mirostat != 0) {
         return nullptr;
     }
+
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
+    // Determine the split point for backend sampling using the same logic as common_sampler_init
+    size_t backend_sampler_count = 0;
+    if (has_logit_bias(params)) {
+        backend_sampler_count++;
+    }
+
+    // Find the longest contiguous chain of backend-supported samplers from the start
+    for (const auto & sampler_type : params.samplers) {
+        if (sampler_backend_supported(sampler_type)) {
+            backend_sampler_count++;
+        } else {
+            break;
+        }
+    }
+
+    if (backend_sampler_count == 0 && !has_logit_bias(params)) {
+        return nullptr;
+    }
+
     llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
     chain_params.no_perf = params.no_perf;
 
     struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
 
-    const bool enable_temp  = params.temp > 0.0f && sampler_enabled(params, COMMON_SAMPLER_TYPE_TEMPERATURE);
-    const bool enable_top_k = params.top_k > 0   && sampler_enabled(params, COMMON_SAMPLER_TYPE_TOP_K);
-    const bool enable_dist  = params.backend_dist;
-
-    if (!params.logit_bias.empty()) {
+    // Add logit_bias to backend chain if present
+    if (has_logit_bias(params)) {
         llama_sampler_chain_add(chain, llama_sampler_backend_init_logit_bias(
             llama_vocab_n_tokens(vocab),
             params.logit_bias.size(),
            params.logit_bias.data()));
     }
 
-    if (enable_temp) {
-        llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
-    }
+    size_t backend_idx = 0;
+    for (const auto & sampler_type : params.samplers) {
+        if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
+            break;
+        }
 
-    if (enable_top_k) {
-        llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
+        switch (sampler_type) {
+            case COMMON_SAMPLER_TYPE_TOP_K:
+                if (params.top_k > 0) {
+                    llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
+                }
+                backend_idx++;
+                break;
+            case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                if (params.temp > 0.0f) {
+                    llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
+                }
+                backend_idx++;
+                break;
+            default:
+                GGML_ASSERT(false && "unsupported backend sampler");
+        }
     }
 
-    if (enable_dist) {
+    // Determine if we should add dist sampler to backend chain
+    // Only add it if all samplers from params.samplers are on the backend
+    size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
+    bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
+
+    if (!cpu_has_samplers) {
         llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
     }
 
@@ -362,9 +488,12 @@ struct llama_sampler * common_sampler_backend_init(const struct llama_model * mo
 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
         llama_sampler_free(gsmpl->grmr);
-
         llama_sampler_free(gsmpl->chain);
 
+        if (gsmpl->backend_chain) {
+            llama_sampler_free(gsmpl->backend_chain);
+        }
+
         delete gsmpl;
     }
 }
@@ -387,12 +516,13 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
-        /* .params = */ gsmpl->params,
-        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
-        /* .prev   = */ gsmpl->prev,
-        /* .cur    = */ gsmpl->cur,
-        /* .cur_p  = */ gsmpl->cur_p,
+        /* .params        = */ gsmpl->params,
+        /* .grmr          = */ llama_sampler_clone(gsmpl->grmr),
+        /* .chain         = */ llama_sampler_clone(gsmpl->chain),
+        /* .backend_chain = */ gsmpl->backend_chain ? llama_sampler_clone(gsmpl->backend_chain) : nullptr,
+        /* .prev          = */ gsmpl->prev,
+        /* .cur           = */ gsmpl->cur,
+        /* .cur_p         = */ gsmpl->cur_p,
     };
 }
 
```
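As a reading aid, here is a small worked example of the split computed by the code above. The concrete chain (`--samplers 'top_k;temperature;top_p'`, no logit bias, `mirostat == 0`) is an assumption for illustration; the values follow directly from the loops in the diff and are not additional code in the patch.

```cpp
// Hypothetical walk-through of common_sampler_init() for
// params.samplers = { TOP_K, TEMPERATURE, TOP_P }, no logit bias, mirostat == 0:
//
//   backend_sampler_count = 2     // top_k and temperature are backend-supported; top_p breaks the loop
//   cpu_start_idx         = 2 - 0 // has_logit_bias(params) is false
//   cpu_has_samplers      = true  // index 2 (top_p) < params.samplers.size() == 3
//
// Result: backend chain = [ top_k, temperature ], CPU chain = [ top_p, dist ].
// With --samplers 'top_k;temperature' instead, cpu_has_samplers would be false and
// llama_sampler_backend_init_dist() would be appended to the backend chain.
```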

examples/batched/README.md

Lines changed: 30 additions & 17 deletions

````diff
@@ -45,25 +45,38 @@ llama_print_timings: total time = 4156.04 ms
 
 ### Using backend samplers
 It is possible to run this example using backend samplers so that sampling is
-performed on the backend device, like a GPU.
+performed on a backend device, like a GPU.
 ```bash
 ./llama-batched \
     -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf -p "Hello my name is" \
-    -np 4 -kvu \
-    --backend_sampling --top-k 80 --backend_dist
+    -np 4 \
+    -kvu \
+    --backend_sampling \
+    --samplers 'top_k;temperature' \
+    --top-k 80
 ```
-The `--verbose` flag can be added to see more detailed output and also show
-that the backend samplers are being used. The above example will perform distribution
-sampling on the backend device and only transfer the sampled token ids back to the host.
+The samplers specified with `--samplers` must be supported by the backend and
+this is why we are explicitly specifying only `top_k` and `temperature` here as
+at the time of writing these are supported.
 
-It is also possible to perform partial sampling on the backend, and then allow CPU samplers
-to process those results further. This is sometimes referred to as hybrid sampling.
-For an example of this we can remove `--backend_dist` from the above command:
-```bash
-./llama-batched \
-    -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf -p "Hello my name is" \
-    -np 4 -kvu \
-    --backend_sampling --top-k 80 -v
-```
-This will perform the top-k filtering on the backend device, and then transfer the filtered logits
-back to the host for sampling.
+The `--verbose` flag can be added to see more detailed output and also show
+that the backend samplers are being used.
+
+With `--backend_sampling` enabled, the sampler chain is automatically analyzed
+to determine which samplers can run on the backend. The system finds the longest
+contiguous chain of backend-supported samplers from the start of the sampler
+sequence. For example:
+* If the chain is `top-k -> temperature -> top-p`, and both `top-k` and
+  `temperature` are backend-supported but `top-p` is not, then `top-k` and
+  `temperature` will run on the backend, while `top-p` and subsequent samplers
+  run on the CPU.
+* If all configured samplers are supported, the final distribution sampling will
+  also happen on the backend, transferring only the sampled token IDs back to the
+  host.
+* If the sampler chain starts with an unsupported sampler (e.g., `penalties`),
+  all sampling runs on the CPU.
+
+**Note:** The default sampler chain includes `penalties` as the first sampler,
+which is not backend-supported yet. To use backend sampling, you must explicitly
+configure a sampler chain that starts with backend-supported samplers using
+`--samplers` like shown above.
````
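For the hybrid case the README describes (some samplers on the backend, the rest on the CPU), a command along the following lines should work. This is an untested sketch that simply extends the chain from the README example with `top_p`, which is not backend-supported at the time of writing:

```bash
./llama-batched \
    -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf -p "Hello my name is" \
    -np 4 \
    -kvu \
    --backend_sampling \
    --samplers 'top_k;temperature;top_p' \
    --top-k 80
```

Here `top_k` and `temperature` would run on the backend, while `top_p` and the final distribution sampling run on the CPU.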

tools/server/public/index.html.gz

3.12 KB
Binary file not shown.
