@@ -64,6 +64,33 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
6464 }
6565}
6666
67+ // return a list of splits for a given path
68+ // for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
69+ static std::vector<std::string> llama_get_list_splits (const std::string & path, const int idx, const int n_split) {
70+ std::vector<std::string> paths;
71+ std::string split_prefix;
72+ std::vector<char > buf (llama_path_max (), 0 );
73+
74+ {
75+ int ret = llama_split_prefix (buf.data (), buf.size (), path.c_str (), idx, n_split);
76+ if (!ret) {
77+ throw std::runtime_error (format (" invalid split file name: %s" , path.c_str ()));
78+ }
79+ split_prefix = std::string (buf.data (), ret);
80+ }
81+
82+ if (split_prefix.empty ()) {
83+ throw std::runtime_error (format (" invalid split file: %s" , path.c_str ()));
84+ }
85+
86+ for (int idx = 0 ; idx < n_split; ++idx) {
87+ int ret = llama_split_path (buf.data (), buf.size (), split_prefix.c_str (), idx, n_split);
88+ paths.push_back (std::string (buf.data (), ret));
89+ }
90+
91+ return paths;
92+ }
93+
6794namespace GGUFMeta {
6895 template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t )>
6996 struct GKV_Base_Type {
@@ -413,7 +440,12 @@ namespace GGUFMeta {
413440 template bool llama_model_loader::get_key_or_arr<std::array<int , 4 >>(enum llm_kv kid, std::array<int , 4 > & result, uint32_t n, bool required);
414441 template bool llama_model_loader::get_key_or_arr<std::array<uint32_t , 512 >>(enum llm_kv kid, std::array<uint32_t , 512 > & result, uint32_t n, bool required);
415442
416- llama_model_loader::llama_model_loader (const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
443+ llama_model_loader::llama_model_loader (
444+ const std::string & fname,
445+ std::vector<std::string> & splits,
446+ bool use_mmap,
447+ bool check_tensors,
448+ const struct llama_model_kv_override * param_overrides_p) {
417449 int trace = 0 ;
418450 if (getenv (" LLAMA_TRACE" )) {
419451 trace = atoi (getenv (" LLAMA_TRACE" ));
@@ -425,6 +457,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
425457 }
426458 }
427459
460+ // Load the main GGUF
428461 struct ggml_context * ctx = NULL ;
429462 struct gguf_init_params params = {
430463 /* .no_alloc = */ true ,
@@ -460,35 +493,54 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap,
460493
461494 // Load additional GGML contexts
462495 if (n_split > 1 ) {
496+ // make sure the main file is loaded first
463497 uint16_t idx = 0 ;
464- get_key (llm_kv (LLM_KV_SPLIT_NO), idx);
498+ const std::string kv_split_no = llm_kv (LLM_KV_SPLIT_NO);
499+ get_key (kv_split_no, idx);
465500 if (idx != 0 ) {
466- throw std::runtime_error (format (" illegal split file: %d, model must be loaded with the first split" , idx));
501+ throw std::runtime_error (format (" illegal split file idx: %d (file: %s), model must be loaded with the first split" , idx, fname.c_str ()));
502+ }
503+
504+ // generate list of splits if needed
505+ if (splits.empty ()) {
506+ splits = llama_get_list_splits (fname, idx, n_split);
467507 }
468508
469- std::vector< char > split_prefix ( llama_path_max (), 0 );
470- if (! llama_split_prefix (split_prefix. data (), split_prefix .size (), fname. c_str (), idx, n_split )) {
471- throw std::runtime_error (format (" invalid split file : %s " , fname. c_str () ));
509+ // in case the user gives a custom list of splits, check if it matches the expected number
510+ if (n_split != ( uint16_t )splits .size ()) {
511+ throw std::runtime_error (format (" invalid split count, given : %zu splits, but expected %d " , splits. size (), n_split ));
472512 }
473513
474514 if (trace > 0 ) {
475515 LLAMA_LOG_INFO (" %s: loading additional %d GGUFs\n " , __func__, n_split);
476516 }
477517
478- std::vector< char > split_path ( llama_path_max (), 0 );
518+ // load other splits
479519 for (idx = 1 ; idx < n_split; idx++) {
480- llama_split_path (split_path. data (), split_path. size (), split_prefix. data (), idx, n_split );
520+ const char * fname_split = splits[idx]. c_str ( );
481521
482522 struct gguf_init_params split_params = {
483523 /* .no_alloc = */ true ,
484524 /* .ctx = */ &ctx,
485525 };
486- gguf_context_ptr ctx_gguf { gguf_init_from_file (split_path. data () , split_params) };
526+ gguf_context_ptr ctx_gguf { gguf_init_from_file (fname_split , split_params) };
487527 if (!ctx_gguf) {
488- throw std::runtime_error (format (" %s: failed to load GGUF split from %s\n " , __func__, split_path.data ()));
528+ throw std::runtime_error (format (" %s: failed to load GGUF split from %s\n " , __func__, fname_split));
529+ }
530+
531+ // check idx
532+ {
533+ const int kid = gguf_find_key (ctx_gguf.get (), kv_split_no.c_str ());
534+ if (kid < 0 ) {
535+ throw std::runtime_error (format (" missing key %s in GGUF split %s" , kv_split_no.c_str (), fname_split));
536+ }
537+ int idx_gguf = gguf_get_val_u16 (ctx_gguf.get (), kid);
538+ if (idx_gguf != idx) {
539+ throw std::runtime_error (format (" invalid split file idx: %d (file: %s), expected %d" , idx_gguf, fname_split, idx));
540+ }
489541 }
490542
491- files.emplace_back (new llama_file (split_path. data () , " rb" ));
543+ files.emplace_back (new llama_file (fname_split , " rb" ));
492544 contexts.emplace_back (ctx);
493545
494546 // Save tensors data offset info of the shard.
0 commit comments