
Commit 7f9cc20

common : refactor args

ggml-ci

1 parent c8880e7

File tree: 22 files changed, +330 −294 lines

common/arg.cpp

Lines changed: 231 additions & 214 deletions
Large diffs are not rendered by default.

common/common.cpp

Lines changed: 2 additions & 2 deletions
@@ -925,9 +925,9 @@ struct common_init_result common_init_from_params(common_params & params) {
         common_lora_adapters_apply(lctx, iparams.lora_adapters);
     }
 
-    if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
+    if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
         LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
-        params.sparams.ignore_eos = false;
+        params.sampling.ignore_eos = false;
     }
 
     if (params.warmup) {
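
With this rename, the --ignore-eos state lives in params.sampling rather than params.sparams. A minimal caller-side sketch under that assumption (model loading and error handling omitted; the field and function names are taken from this diff):

    common_params params;
    params.sampling.ignore_eos = true; // was params.sparams.ignore_eos

    // If the model has no EOS token, common_init_from_params() warns and
    // resets the flag, as shown in the hunk above.
    common_init_result res = common_init_from_params(params);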

common/common.h

Lines changed: 18 additions & 10 deletions
@@ -103,8 +103,8 @@ enum dimre_method {
     DIMRE_METHOD_MEAN,
 };
 
-// sampler parameters
-struct common_sampler_params {
+// sampling parameters
+struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
 
     int32_t n_prev = 64; // number of previous tokens to remember
@@ -155,20 +155,30 @@ struct common_sampler_params {
     std::string print() const;
 };
 
+struct common_params_speculative {
+    int32_t n_ctx        = 4096; // draft context size
+    int32_t n_max        =    5; // maximum number of tokens to draft during speculative decoding
+    int32_t n_min        =    0; // minimum number of draft tokens to use for speculative decoding
+    int32_t n_gpu_layers =   -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+    float   p_split      = 0.1f; // speculative decoding split probability
+    float   p_min        = 0.9f; // minimum speculative decoding probability (greedy)
+
+    struct cpu_params cpuparams;
+    struct cpu_params cpuparams_batch;
+
+    std::string model = ""; // draft model for speculative decoding // NOLINT
+};
+
 struct common_params {
     int32_t n_predict          =    -1; // new tokens to predict
     int32_t n_ctx              =  4096; // context size
     int32_t n_batch            =  2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch           =   512; // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep             =     0; // number of tokens to keep from initial prompt
-    int32_t n_draft            =     5; // number of tokens to draft during speculative decoding
-    int32_t n_draft_min        =     0; // minimum number of draft tokens to use for speculative decoding
     int32_t n_chunks           =    -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel         =     1; // number of parallel sequences to decode
     int32_t n_sequences        =     1; // number of sequences to decode
-    float   p_split            =  0.1f; // speculative decoding split probability
     int32_t n_gpu_layers       =    -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t n_gpu_layers_draft =    -1; // number of layers to store in VRAM for the draft model (-1 - use default)
     int32_t main_gpu           =     0; // the GPU that is used for scratch and small tensors
     float   tensor_split[128]  =   {0}; // how split tensors should be distributed across GPUs
     int32_t grp_attn_n         =     1; // group-attention factor
@@ -185,8 +195,6 @@ struct common_params {
 
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
-    struct cpu_params draft_cpuparams;
-    struct cpu_params draft_cpuparams_batch;
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
@@ -198,10 +206,10 @@ struct common_params {
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    struct common_sampler_params sparams;
+    struct common_params_sampling sampling;
+    struct common_params_speculative speculative;
 
     std::string model       = "";        // model path // NOLINT
-    std::string model_draft = "";        // draft model for speculative decoding // NOLINT
     std::string model_alias = "unknown"; // model alias // NOLINT
     std::string model_url   = "";        // model url to download // NOLINT
     std::string hf_token    = "";        // HF token // NOLINT
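
Taken together, these common.h changes fold the old flat draft-model fields into the new common_params_speculative struct and rename sparams to sampling. A hedged before/after sketch of the field mapping (names and defaults come from this diff; the model path is a placeholder):

    common_params params;

    // sampler settings: params.sparams.* -> params.sampling.*
    params.sampling.seed = LLAMA_DEFAULT_SEED;

    // speculative-decoding settings move into params.speculative.*
    params.speculative.n_max        = 5;    // was params.n_draft
    params.speculative.n_min        = 0;    // was params.n_draft_min
    params.speculative.p_split      = 0.1f; // was params.p_split
    params.speculative.n_gpu_layers = -1;   // was params.n_gpu_layers_draft
    params.speculative.model = "draft.gguf"; // was params.model_draft (placeholder path)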

common/sampling.cpp

Lines changed: 3 additions & 3 deletions
@@ -99,7 +99,7 @@ struct ring_buffer {
 };
 
 struct common_sampler {
-    common_sampler_params params;
+    common_params_sampling params;
 
     struct llama_sampler * grmr;
     struct llama_sampler * chain;
@@ -125,7 +125,7 @@ struct common_sampler {
     }
 };
 
-std::string common_sampler_params::print() const {
+std::string common_params_sampling::print() const {
     char result[1024];
 
     snprintf(result, sizeof(result),
@@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
     return std::string(result);
 }
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
 
     lparams.no_perf = params.no_perf;

common/sampling.h

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ struct common_sampler;
 
 // llama_sampler API overloads
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
 
 void common_sampler_free(struct common_sampler * gsmpl);
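
A minimal sketch of a call site against the renamed API (the llama_model pointer is assumed to come from the usual loading path; the other names appear in this commit):

    common_params_sampling sparams; // renamed from common_sampler_params
    sparams.top_k    = 40;
    sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };

    struct common_sampler * smpl = common_sampler_init(model, sparams);
    // ... sample with smpl ...
    common_sampler_free(smpl);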

common/speculative.cpp

Lines changed: 11 additions & 11 deletions
@@ -29,32 +29,32 @@ struct common_speculative * common_speculative_init(
     // TODO: optimize or pass from outside?
 #if 0
     {
-        common_sampler_params sparams;
-        sparams.no_perf = false;
+        common_params_sampling params;
+        params.no_perf = false;
 
-        sparams.top_k = 40;
-        sparams.top_p = 0.9;
+        params.top_k = 40;
+        params.top_p = 0.9;
 
-        sparams.samplers = {
+        params.samplers = {
             COMMON_SAMPLER_TYPE_TOP_K,
             COMMON_SAMPLER_TYPE_TOP_P,
             COMMON_SAMPLER_TYPE_INFILL,
         };
 
-        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
     }
 #else
     {
-        common_sampler_params sparams;
-        sparams.no_perf = false;
+        common_params_sampling params;
+        params.no_perf = false;
 
-        sparams.top_k = 10;
+        params.top_k = 10;
 
-        sparams.samplers = {
+        params.samplers = {
             COMMON_SAMPLER_TYPE_TOP_K,
         };
 
-        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
     }
 #endif

examples/batched/batched.cpp

Lines changed: 4 additions & 4 deletions
@@ -68,10 +68,10 @@ int main(int argc, char ** argv) {
 
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
-    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k));
-    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep));
-    llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp));
-    llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed));
+    llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
+    llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
+    llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
+    llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
 
     if (ctx == NULL) {
         LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);

examples/infill/infill.cpp

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    auto & sparams = params.sparams;
+    auto & sparams = params.sampling;
 
     console::init(params.simple_io, params.use_color);
     atexit([]() { console::cleanup(); });

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
 
     LOG("\n");
 
-    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
+    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
     if (!smpl) {
         LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);

examples/llava/minicpmv-cli.cpp

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ static struct common_sampler * llama_init(struct llava_context * ctx_llava, comm
 
     LOG_INF("\n");
 
-    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
+    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
     return smpl;
 }
