@@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
     struct llama_lora_adapter * adapter;
 };
 
+using llama_tokens = std::vector<llama_token>;
+
 // build info
 extern int LLAMA_BUILD_NUMBER;
 extern char const * LLAMA_COMMIT;
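The new `llama_tokens` alias is plain shorthand for `std::vector<llama_token>`. A minimal sketch of it in use, assuming the `common_tokenize` wrapper declared elsewhere in this header:

```cpp
#include "common.h"

// tokenize a prompt and keep the result in the new alias type;
// common_tokenize returns std::vector<llama_token>, i.e. llama_tokens
llama_tokens tokenize_prompt(llama_context * ctx, const std::string & text) {
    return common_tokenize(ctx, text, /*add_special=*/true, /*parse_special=*/true);
}
```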
@@ -101,8 +103,8 @@ enum dimre_method {
     DIMRE_METHOD_MEAN,
 };
 
-// sampler parameters
-struct common_sampler_params {
+// sampling parameters
+struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
 
     int32_t n_prev = 64; // number of previous tokens to remember
@@ -153,19 +155,30 @@ struct common_sampler_params {
     std::string print() const;
 };
 
+struct common_params_speculative {
+    int32_t n_ctx        = 0;    // draft context size
+    int32_t n_max        = 16;   // maximum number of tokens to draft during speculative decoding
+    int32_t n_min        = 5;    // minimum number of draft tokens to use for speculative decoding
+    int32_t n_gpu_layers = -1;   // number of layers to store in VRAM for the draft model (-1 - use default)
+    float   p_split      = 0.1f; // speculative decoding split probability
+    float   p_min        = 0.9f; // minimum speculative decoding probability (greedy)
+
+    struct cpu_params cpuparams;
+    struct cpu_params cpuparams_batch;
+
+    std::string model = ""; // draft model for speculative decoding // NOLINT
+};
+
 struct common_params {
     int32_t n_predict    = -1;   // new tokens to predict
     int32_t n_ctx        = 4096; // context size
     int32_t n_batch      = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch     = 512;  // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep       = 0;    // number of tokens to keep from initial prompt
-    int32_t n_draft      = 5;    // number of tokens to draft during speculative decoding
     int32_t n_chunks     = -1;   // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel   = 1;    // number of parallel sequences to decode
     int32_t n_sequences  = 1;    // number of sequences to decode
-    float   p_split      = 0.1f; // speculative decoding split probability
     int32_t n_gpu_layers = -1;   // number of layers to store in VRAM (-1 - use default)
-    int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
     int32_t main_gpu     = 0;    // the GPU that is used for scratch and small tensors
     float   tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
     int32_t grp_attn_n   = 1;    // group-attention factor
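All draft-model settings now live in the new `common_params_speculative` sub-struct instead of the removed top-level fields (`n_draft`, `p_split`, `n_gpu_layers_draft`, `model_draft`). A hedged sketch of configuring it, with field names taken from the diff above and the path purely hypothetical:

```cpp
#include "common.h"

void configure_speculation(common_params & params) {
    params.speculative.model        = "models/draft.gguf"; // hypothetical draft-model path
    params.speculative.n_max        = 16;   // cap on tokens drafted per step
    params.speculative.n_min        = 5;    // skip drafting below this many tokens
    params.speculative.p_min        = 0.9f; // greedy acceptance threshold
    params.speculative.n_gpu_layers = -1;   // backend default for the draft model
}
```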
@@ -182,8 +195,6 @@ struct common_params {
 
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
-    struct cpu_params draft_cpuparams;
-    struct cpu_params draft_cpuparams_batch;
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
@@ -195,10 +206,10 @@ struct common_params {
     enum llama_pooling_type pooling_type     = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    struct common_sampler_params sparams;
+    struct common_params_sampling sampling;
+    struct common_params_speculative speculative;
 
     std::string model       = "";        // model path // NOLINT
-    std::string model_draft = "";        // draft model for speculative decoding // NOLINT
     std::string model_alias = "unknown"; // model alias // NOLINT
     std::string model_url   = "";        // model url to download // NOLINT
     std::string hf_token    = "";        // HF token // NOLINT
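Downstream code referring to the old members needs a mechanical rename. A before/after sketch, with member names taken from the diff and the values purely illustrative (the old `n_draft` appears to map roughly onto the new `n_max`, hedged accordingly):

```cpp
#include "common.h"

void migrate_example(common_params & params) {
    // previously: params.sparams.seed, params.model_draft, params.n_draft
    params.sampling.seed     = 42;           // was params.sparams.seed
    params.speculative.model = "draft.gguf"; // was params.model_draft (hypothetical path)
    params.speculative.n_max = 8;            // roughly replaces the old n_draft
}
```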
@@ -461,7 +472,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
 // clear LoRA adapters from context, then apply new list of adapters
 void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
 
+//
+// Batch utils
+//
 
 void common_batch_clear(struct llama_batch & batch);
 
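For reference, a minimal usage sketch of the two batch helpers declared here, using only the signatures visible in this diff (single sequence, logits requested for the final token only):

```cpp
#include "common.h"

// fill a batch with one sequence of tokens
void fill_batch(llama_batch & batch, const llama_tokens & tokens) {
    common_batch_clear(batch); // resets the batch to empty
    for (size_t i = 0; i < tokens.size(); ++i) {
        const bool last = (i == tokens.size() - 1);
        // token i at position i in sequence 0; request logits only for the last token
        common_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, last);
    }
}
```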
@@ -472,6 +485,16 @@ void common_batch_add(
     const std::vector<llama_seq_id> & seq_ids,
     bool logits);
 
+//
+// Token utils
+//
+
+// longest common prefix
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
+
+// longest common subsequence
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
+
 //
 // Vocab utils
 //
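Only the declarations are added here; the definitions land in common.cpp. For intuition, a plausible sketch of each, matching the comments above (illustrative only, not the actual implementations, hence the `_sketch` suffixes; `llama_tokens` as defined earlier in this header):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// longest common prefix: count matching tokens from the front
static size_t lcp_sketch(const llama_tokens & a, const llama_tokens & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

// longest common subsequence length via the classic O(a.size() * b.size()) DP
static size_t lcs_sketch(const llama_tokens & a, const llama_tokens & b) {
    std::vector<std::vector<size_t>> dp(a.size() + 1, std::vector<size_t>(b.size() + 1, 0));
    for (size_t i = 1; i <= a.size(); i++) {
        for (size_t j = 1; j <= b.size(); j++) {
            dp[i][j] = (a[i - 1] == b[j - 1])
                ? dp[i - 1][j - 1] + 1
                : std::max(dp[i - 1][j], dp[i][j - 1]);
        }
    }
    return dp[a.size()][b.size()];
}
```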