
Commit d0bea21

examples : update batched to use backend sampling
This commit updates the batched example to demonstrate how to use backend samplers.
1 parent 25f3380 commit d0bea21

File tree

2 files changed: +25 −0 lines changed


examples/batched/README.md

Lines changed: 12 additions & 0 deletions
````diff
@@ -42,3 +42,15 @@ llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms
 llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
 llama_print_timings: total time = 4156.04 ms
 ```
+
+### Using backend samplers
+It is possible to run this example using backend samplers so that sampling is
+performed on the backend device, like a GPU.
+```bash
+./llama-batched \
+    -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf -p "Hello my name is" \
+    -np 4 -kvu \
+    --backend_sampling --top-k 80 --backend_dist
+```
+The `--verbose` flag can be added to see more detailed output and also show
+that the backend samplers are being used.
````

examples/batched/batched.cpp

Lines changed: 13 additions & 0 deletions
```diff
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "sampling.h"
 
 #include <algorithm>
 #include <cstdio>
@@ -64,6 +65,18 @@ int main(int argc, char ** argv) {
     ctx_params.n_ctx   = n_kv_req;
     ctx_params.n_batch = std::max(n_predict, n_parallel);
 
+    std::vector<llama_sampler_seq_config> sampler_configs(n_parallel);
+    if (params.sampling.backend_sampling) {
+        for (int32_t i = 0; i < n_parallel; ++i) {
+            llama_sampler * backend_sampler = common_sampler_backend_init(model, params.sampling);
+            if (backend_sampler) {
+                sampler_configs[i] = { i, backend_sampler };
+            }
+        }
+        ctx_params.samplers   = sampler_configs.data();
+        ctx_params.n_samplers = n_parallel;
+    }
+
     llama_context * ctx = llama_init_from_model(model, ctx_params);
 
     auto sparams = llama_sampler_chain_default_params();
```

0 commit comments