
Commit 20fe00e

Merge branch 'ggml-org:master' into master

2 parents: 2a11fb8 + 5d804a4

39 files changed: +1716 -807 lines

README.md

Lines changed: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
 - [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
 - [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
+- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
 
 #### Multimodal

ci/run.sh

Lines changed: 10 additions & 10 deletions
@@ -386,10 +386,10 @@ function gg_run_open_llama_7b_v2 {
 
     (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
 
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
 
 function check_ppl {
     qnt="$1"
@@ -520,8 +520,8 @@ function gg_run_pythia_1_4b {
 
     (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
 
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
 
 function check_ppl {
     qnt="$1"
@@ -651,10 +651,10 @@ function gg_run_pythia_2_8b {
 
     (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
 
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
 
 function check_ppl {
     qnt="$1"

common/arg.cpp

Lines changed: 7 additions & 7 deletions
@@ -2962,20 +2962,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.endpoint_metrics = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS"));
-    add_opt(common_arg(
-        {"--slots"},
-        string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
-        [](common_params & params) {
-            params.endpoint_slots = true;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
     add_opt(common_arg(
         {"--props"},
         string_format("enable changing global properties via POST /props (default: %s)", params.endpoint_props ? "enabled" : "disabled"),
         [](common_params & params) {
             params.endpoint_props = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_PROPS"));
+    add_opt(common_arg(
+        {"--slots"},
+        string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.endpoint_slots = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
     add_opt(common_arg(
         {"--no-slots"},
         "disables slots monitoring endpoint",

common/common.h

Lines changed: 1 addition & 1 deletion
@@ -444,7 +444,7 @@ struct common_params {
 
     // "advanced" endpoints are disabled by default for better security
     bool webui = true;
-    bool endpoint_slots = false;
+    bool endpoint_slots = true;
     bool endpoint_props = false; // only control POST requests, not GET
     bool endpoint_metrics = false;
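Read together with the common/arg.cpp hunk above, this flips the slots monitoring endpoint to on by default while props and metrics stay opt-in; `--no-slots` remains the way to turn it back off. A minimal sketch of the resulting defaults, assuming the llama.cpp tree (the check program itself is hypothetical, not part of this commit):

    #include "common.h"
    #include <cassert>

    int main() {
        common_params params;
        assert(params.endpoint_slots);    // flipped to true by this commit
        assert(!params.endpoint_props);   // still opt-in (POST /props only)
        assert(!params.endpoint_metrics); // still opt-in
        return 0;
    }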

common/sampling.cpp

Lines changed: 23 additions & 2 deletions
@@ -426,8 +426,29 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
 
 // helpers
 
-llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
-    return &gsmpl->cur_p;
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
+    auto * res = &gsmpl->cur_p;
+
+    if (do_sort && !res->sorted) {
+        // remember the selected token before sorting
+        const llama_token id = res->data[res->selected].id;
+
+        std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
+            return a.p > b.p;
+        });
+
+        // restore the selected token after sorting
+        for (size_t i = 0; i < res->size; ++i) {
+            if (res->data[i].id == id) {
+                res->selected = i;
+                break;
+            }
+        }
+
+        res->sorted = true;
+    }
+
+    return res;
 }
 
 llama_token common_sampler_last(const struct common_sampler * gsmpl) {

common/sampling.h

Lines changed: 3 additions & 1 deletion
@@ -86,7 +86,9 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 // helpers
 
 // access the internal list of current candidate tokens
-llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
+// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
+// the .sorted flag of the result indicates whether the returned candidates are sorted
+llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
 
 // get the last accepted token
 llama_token common_sampler_last(const struct common_sampler * gsmpl);
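A minimal usage sketch for the updated helper, following the call pattern in the speculative changes below; `smpl` and `ctx` are assumed to be an already-initialized `common_sampler *` and `llama_context *` (placeholder names, not part of this commit):

    // sample first, then fetch the candidates sorted by descending probability;
    // with do_sort == true, cur_p->sorted is guaranteed afterwards and
    // cur_p->selected still refers to the sampled token
    common_sampler_sample(smpl, ctx, 0, true);
    const auto * cur_p = common_sampler_get_candidates(smpl, /*do_sort=*/true);

    for (size_t k = 0; k < cur_p->size && k < 3; ++k) {
        LOG_DBG("candidate %3zu: %6d (%8.3f)\n", k, cur_p->data[k].id, cur_p->data[k].p);
    }

Callers that can tolerate an unsorted list can pass `false` and consult the `.sorted` flag instead, avoiding the sort when it is not needed.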

common/speculative.cpp

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ llama_tokens common_speculative_gen_draft(
 
     common_sampler_sample(smpl, ctx_dft, 0, true);
 
-    const auto * cur_p = common_sampler_get_candidates(smpl);
+    const auto * cur_p = common_sampler_get_candidates(smpl, true);
 
     for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
         LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",

convert_hf_to_gguf.py

Lines changed: 0 additions & 4 deletions
@@ -302,10 +302,6 @@ def prepare_tensors(self):
             # data = data_torch.squeeze().numpy()
             data = data_torch.numpy()
 
-            # if data ends up empty, it means data_torch was a scalar tensor -> restore
-            if len(data.shape) == 0:
-                data = data_torch.numpy()
-
             n_dims = len(data.shape)
             data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)

examples/speculative/speculative.cpp

Lines changed: 2 additions & 2 deletions
@@ -244,7 +244,7 @@ int main(int argc, char ** argv) {
         // stochastic verification
         common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
 
-        auto & dist_tgt = *common_sampler_get_candidates(smpl);
+        auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
 
         float p_tgt = 0.0f;
         float p_dft = 0.0f;
@@ -493,7 +493,7 @@ int main(int argc, char ** argv) {
 
             common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
 
-            const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
+            const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
 
             for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
                 LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",

ggml/include/ggml-backend.h

Lines changed: 3 additions & 0 deletions
@@ -307,6 +307,9 @@ extern "C" {
     GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
     GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
+    // Split graph without allocating it
+    GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+
     // Allocate and compute graph on the backend scheduler
     GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
     GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
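The new entry point exposes the scheduler's partitioning step on its own. A hedged sketch of how it composes with the neighbouring API, assuming `sched`, `graph`, and `node` were created through the usual `ggml_backend_sched_new` and graph-building paths (placeholders, not shown in this diff); error handling omitted:

    // partition the graph across backends without allocating any buffers
    ggml_backend_sched_split_graph(sched, graph);

    // the split can now be inspected, e.g. where a given node was placed
    ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, node);
    (void) backend;

    // allocate and compute as before
    if (ggml_backend_sched_alloc_graph(sched, graph)) {
        enum ggml_status status = ggml_backend_sched_graph_compute(sched, graph);
        (void) status;
    }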
