@@ -413,7 +413,12 @@ namespace GGUFMeta {
 template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
 template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
 
-llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
+llama_model_loader::llama_model_loader(
+        const std::string & fname,
+        std::vector<std::string> & splits,
+        bool use_mmap,
+        bool check_tensors,
+        const struct llama_model_kv_override * param_overrides_p) {
     int trace = 0;
     if (getenv("LLAMA_TRACE")) {
         trace = atoi(getenv("LLAMA_TRACE"));
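
For context, a minimal sketch of what a call site for the new constructor signature could look like. The paths and the call itself are illustrative assumptions, not part of this commit; passing an empty vector keeps the previous behavior (the split list is derived from fname), while a non-empty vector supplies a user-provided list:

    // Hypothetical call site (not from this commit): loading a model sharded
    // across three GGUF files with an explicit list of split paths.
    std::vector<std::string> splits = {
        "/models/foo-00001-of-00003.gguf", // must be the first split (idx 0)
        "/models/foo-00002-of-00003.gguf",
        "/models/foo-00003-of-00003.gguf",
    };
    // An empty `splits` vector falls back to llama_get_list_splits(fname, ...).
    llama_model_loader ml(
        /* fname             = */ splits[0],
        /* splits            = */ splits,
        /* use_mmap          = */ true,
        /* check_tensors     = */ false,
        /* param_overrides_p = */ nullptr);
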
@@ -425,6 +430,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
         }
     }
 
+    // Load the main GGUF
     struct ggml_context * ctx = NULL;
     struct gguf_init_params params = {
         /*.no_alloc = */ true,
@@ -460,35 +466,52 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
 
     // Load additional GGML contexts
     if (n_split > 1) {
-        uint16_t idx = 0;
-        get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
-        if (idx != 0) {
-            throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
+        // generate list of splits if needed
+        if (splits.empty()) {
+            splits = llama_get_list_splits(fname, n_split);
         }
 
-        std::vector<char> split_prefix(llama_path_max(), 0);
-        if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) {
-            throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
+        // in case the user gives a custom list of splits, check that it matches the expected number
+        if (n_split != (uint16_t)splits.size()) {
+            throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split));
+        }
+
+        uint16_t idx = 0;
+        const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO);
+        get_key(kv_split_no, idx);
+        if (idx != 0) {
+            throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str()));
         }
 
         if (trace > 0) {
             LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
         }
 
-        std::vector<char> split_path(llama_path_max(), 0);
         for (idx = 1; idx < n_split; idx++) {
-            llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split);
+            const char * fname_split = splits[idx].c_str();
 
             struct gguf_init_params split_params = {
                 /*.no_alloc = */ true,
                 /*.ctx      = */ &ctx,
             };
-            gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) };
+            gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) };
             if (!ctx_gguf) {
-                throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data()));
+                throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split));
+            }
+
+            // check idx
+            {
+                const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str());
+                if (kid < 0) {
+                    throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split));
+                }
+                int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid);
+                if (idx_gguf != idx) {
+                    throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx));
+                }
             }
 
-            files.emplace_back(new llama_file(split_path.data(), "rb"));
+            files.emplace_back(new llama_file(fname_split, "rb"));
             contexts.emplace_back(ctx);
 
             // Save tensors data offset info of the shard.
@@ -1070,3 +1093,28 @@ void llama_model_loader::print_info() const {
         LLAMA_LOG_INFO("%s: file size   = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements);
     }
 }
+
+std::vector<std::string> llama_get_list_splits(const std::string & path, const int n_split) {
+    std::vector<std::string> paths;
+    std::string split_prefix;
+    std::vector<char> buf(llama_path_max(), 0);
+
+    // brute force to find the split prefix
+    for (int idx = 0; idx < n_split; ++idx) {
+        int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split);
+        if (ret) {
+            split_prefix = std::string(buf.data(), ret);
+        }
+    }
+
+    if (split_prefix.empty()) {
+        throw std::runtime_error(format("invalid split file: %s", path.c_str()));
+    }
+
+    for (int idx = 0; idx < n_split; ++idx) {
+        int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split);
+        paths.push_back(std::string(buf.data(), ret));
+    }
+
+    return paths;
+}
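
For reference, here is a self-contained sketch of the split naming convention that llama_split_prefix() and llama_split_path() implement and that llama_get_list_splits() relies on. The "-%05d-of-%05d.gguf" suffix matches the convention used by the gguf-split tool; make_split_path() below is a hypothetical stand-in for illustration, not the library API:

    #include <cstdio>
    #include <string>

    // Rebuild the path of 0-based split `idx` out of `n_split` from a common
    // prefix; the on-disk numbering is 1-based (assumed format).
    static std::string make_split_path(const std::string & prefix, int idx, int n_split) {
        char buf[1024];
        std::snprintf(buf, sizeof(buf), "%s-%05d-of-%05d.gguf", prefix.c_str(), idx + 1, n_split);
        return buf;
    }

    int main() {
        const std::string prefix = "/models/foo"; // hypothetical prefix
        const int n_split = 3;
        for (int idx = 0; idx < n_split; ++idx) {
            std::printf("%s\n", make_split_path(prefix, idx, n_split).c_str());
        }
        // prints:
        //   /models/foo-00001-of-00003.gguf
        //   /models/foo-00002-of-00003.gguf
        //   /models/foo-00003-of-00003.gguf
        return 0;
    }

llama_get_list_splits() inverts this mapping: it recovers the common prefix from the first split's path, then regenerates the full list of paths from that prefix.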