@@ -64,6 +64,33 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
6464    }
6565}
6666
67+ //  return a list of splits for a given path
68+ //  for example, given "<name>-00002-of-00004.gguf", returns list of all 4 splits
69+ static  std::vector<std::string> llama_get_list_splits (const  std::string & path, const  int  idx, const  int  n_split) {
70+     std::vector<std::string> paths;
71+     std::string split_prefix;
72+     std::vector<char > buf (llama_path_max (), 0 );
73+ 
74+     {
75+         int  ret = llama_split_prefix (buf.data (), buf.size (), path.c_str (), idx, n_split);
76+         if  (!ret) {
77+             throw  std::runtime_error (format (" invalid split file name: %s"  , path.c_str ()));
78+         }
79+         split_prefix = std::string (buf.data (), ret);
80+     }
81+ 
82+     if  (split_prefix.empty ()) {
83+         throw  std::runtime_error (format (" invalid split file: %s"  , path.c_str ()));
84+     }
85+ 
86+     for  (int  idx = 0 ; idx < n_split; ++idx) {
87+         int  ret = llama_split_path (buf.data (), buf.size (), split_prefix.c_str (), idx, n_split);
88+         paths.push_back (std::string (buf.data (), ret));
89+     }
90+ 
91+     return  paths;
92+ }
93+ 
6794namespace  GGUFMeta  {
6895    template  <typename  T, gguf_type gt_, T (*gfun)(const  gguf_context *, const  int64_t )>
6996    struct  GKV_Base_Type  {
@@ -466,27 +493,29 @@ llama_model_loader::llama_model_loader(
466493
467494    //  Load additional GGML contexts
468495    if  (n_split > 1 ) {
496+         //  make sure the main file is loaded first
497+         uint16_t  idx = 0 ;
498+         const  std::string kv_split_no = llm_kv (LLM_KV_SPLIT_NO);
499+         get_key (kv_split_no, idx);
500+         if  (idx != 0 ) {
501+             throw  std::runtime_error (format (" illegal split file idx: %d (file: %s), model must be loaded with the first split"  , idx, fname.c_str ()));
502+         }
503+ 
469504        //  generate list of splits if needed
470505        if  (splits.empty ()) {
471-             splits = llama_get_list_splits (fname, n_split);
506+             splits = llama_get_list_splits (fname, idx,  n_split);
472507        }
473508
474509        //  in case user give a custom list of splits, check if it matches the expected number
475510        if  (n_split != (uint16_t )splits.size ()) {
476511            throw  std::runtime_error (format (" invalid split count, given: %zu splits, but expected %d"  , splits.size (), n_split));
477512        }
478513
479-         uint16_t  idx = 0 ;
480-         const  std::string kv_split_no = llm_kv (LLM_KV_SPLIT_NO);
481-         get_key (kv_split_no, idx);
482-         if  (idx != 0 ) {
483-             throw  std::runtime_error (format (" illegal split file idx: %d (file: %s), model must be loaded with the first split"  , idx, fname.c_str ()));
484-         }
485- 
486514        if  (trace > 0 ) {
487515            LLAMA_LOG_INFO (" %s: loading additional %d GGUFs\n "  , __func__, n_split);
488516        }
489517
518+         //  load other splits
490519        for  (idx = 1 ; idx < n_split; idx++) {
491520            const  char  * fname_split = splits[idx].c_str ();
492521
@@ -1093,28 +1122,3 @@ void llama_model_loader::print_info() const {
10931122        LLAMA_LOG_INFO (" %s: file size   = %.2f GiB (%.2f BPW) \n "  , __func__, n_bytes/1024.0 /1024.0 /1024.0 , n_bytes*8.0 /n_elements);
10941123    }
10951124}
1096- 
1097- std::vector<std::string> llama_get_list_splits (const  std::string & path, const  int  n_split) {
1098-     std::vector<std::string> paths;
1099-     std::string split_prefix;
1100-     std::vector<char > buf (llama_path_max (), 0 );
1101- 
1102-     //  brute force to find the split prefix
1103-     for  (int  idx = 0 ; idx < n_split; ++idx) {
1104-         int  ret = llama_split_prefix (buf.data (), buf.size (), path.c_str (), idx, n_split);
1105-         if  (ret) {
1106-             split_prefix = std::string (buf.data (), ret);
1107-         }
1108-     }
1109- 
1110-     if  (split_prefix.empty ()) {
1111-         throw  std::runtime_error (format (" invalid split file: %s"  , path.c_str ()));
1112-     }
1113- 
1114-     for  (int  idx = 0 ; idx < n_split; ++idx) {
1115-         int  ret = llama_split_path (buf.data (), buf.size (), split_prefix.c_str (), idx, n_split);
1116-         paths.push_back (std::string (buf.data (), ret));
1117-     }
1118- 
1119-     return  paths;
1120- }
0 commit comments