
Commit 5bbf90d

hongkongkiwi and claude committed
Add split model loading support
This commit introduces comprehensive support for loading models from multiple split files:

- Added `load_from_splits()` method to LlamaModel for loading models split across multiple files
- Added utility functions `split_path()` and `split_prefix()` for working with split file naming conventions
- Added split_model example demonstrating usage of the split loading functionality
- Updated workspace Cargo.toml to include the new split_model example

This feature enables loading very large models that have been split due to filesystem limitations or distribution requirements.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
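For orientation, here is a minimal sketch of how the pieces added in this commit fit together, using only the `LlamaModel::split_path()`, `LlamaModel::split_prefix()`, and `LlamaModel::load_from_splits()` APIs introduced below; the `/models/llama` prefix and the split count of 3 are illustrative placeholders:

```rust
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::model::{params::LlamaModelParams, LlamaModel};
use std::path::PathBuf;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let backend = LlamaBackend::init()?;
    let params = LlamaModelParams::default();

    // Build the conventional split file names from a prefix:
    // "/models/llama-00001-of-00003.gguf", "...-00002-of-00003.gguf", ...
    let paths: Vec<PathBuf> = (1..=3)
        .map(|i| PathBuf::from(LlamaModel::split_path("/models/llama", i, 3)))
        .collect();

    // Load a single logical model from all of the split files.
    let model = LlamaModel::load_from_splits(&backend, &paths, &params)?;
    println!("loaded split model, vocab size = {}", model.n_vocab());

    // split_prefix() recovers the prefix when the split number and count match the filename.
    let prefix = LlamaModel::split_prefix("/models/llama-00001-of-00003.gguf", 1, 3);
    assert_eq!(prefix.as_deref(), Some("/models/llama"));

    Ok(())
}
```

The same flow, plus context creation and sampling, is what the split_model example below walks through end to end.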
1 parent a6c142a commit 5bbf90d

4 files changed, +374 -0 lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
```diff
@@ -7,6 +7,8 @@ members = [
     "examples/simple",
     "examples/reranker",
     "examples/mtmd",
+    "examples/split_model",
+    "examples/rpc",
 ]
 
 [workspace.dependencies]
```

examples/split_model/Cargo.toml

Lines changed: 9 additions & 0 deletions
```diff
@@ -0,0 +1,9 @@
+[package]
+name = "split_model"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+llama-cpp-2 = { path = "../../llama-cpp-2" }
+anyhow = "1.0"
+clap = { version = "4", features = ["derive"] }
```

examples/split_model/src/main.rs

Lines changed: 195 additions & 0 deletions
```diff
@@ -0,0 +1,195 @@
+//! Example demonstrating how to load split GGUF models.
+//!
+//! This example shows how to:
+//! - Load a model split across multiple files
+//! - Use utility functions to work with split file naming conventions
+//! - Generate text from a split model
+
+use anyhow::Result;
+use clap::Parser;
+use llama_cpp_2::{
+    context::params::LlamaContextParams,
+    llama_backend::LlamaBackend,
+    llama_batch::LlamaBatch,
+    model::{params::LlamaModelParams, AddBos, LlamaModel},
+    sampling::LlamaSampler,
+};
+use std::io::{self, Write};
+use std::num::NonZeroU32;
+use std::path::{Path, PathBuf};
+
+/// Command line arguments for the split model example
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Paths to the split model files (can be specified multiple times)
+    #[arg(short = 'm', long = "model", required = true, num_args = 1..)]
+    model_paths: Vec<PathBuf>,
+
+    /// Alternatively, provide a prefix and the program will auto-detect splits
+    #[arg(short = 'p', long = "prefix", conflicts_with = "model_paths")]
+    prefix: Option<String>,
+
+    /// Number of splits (required if using --prefix)
+    #[arg(short = 'n', long = "num-splits", requires = "prefix")]
+    num_splits: Option<u32>,
+
+    /// Prompt to use for generation
+    #[arg(short = 't', long = "prompt", default_value = "Once upon a time")]
+    prompt: String,
+
+    /// Number of tokens to generate
+    #[arg(short = 'g', long = "n-predict", default_value_t = 128)]
+    n_predict: i32,
+
+    /// Number of GPU layers
+    #[arg(short = 'l', long = "n-gpu-layers", default_value_t = 0)]
+    n_gpu_layers: u32,
+
+    /// Context size
+    #[arg(short = 'c', long = "ctx-size", default_value_t = 2048)]
+    ctx_size: u32,
+
+    /// Temperature for sampling
+    #[arg(long = "temp", default_value_t = 0.8)]
+    temperature: f32,
+
+    /// Top-P for sampling
+    #[arg(long = "top-p", default_value_t = 0.95)]
+    top_p: f32,
+
+    /// Seed for random number generation
+    #[arg(long = "seed", default_value_t = 1234)]
+    seed: u32,
+}
+
+fn main() -> Result<()> {
+    let args = Args::parse();
+
+    // Determine the model paths
+    let model_paths = if let Some(prefix) = args.prefix {
+        let num_splits = args.num_splits.expect("num-splits required with prefix");
+
+        // Generate split paths using the utility function
+        let mut paths = Vec::new();
+        for i in 1..=num_splits {
+            let path = LlamaModel::split_path(&prefix, i as i32, num_splits as i32);
+            paths.push(PathBuf::from(path));
+        }
+
+        println!("Generated split paths:");
+        for path in &paths {
+            println!(" - {}", path.display());
+        }
+
+        paths
+    } else {
+        args.model_paths
+    };
+
+    // Verify all split files exist
+    for path in &model_paths {
+        if !path.exists() {
+            eprintln!("Error: Split file not found: {}", path.display());
+            std::process::exit(1);
+        }
+    }
+
+    println!("Loading model from {} splits...", model_paths.len());
+
+    // Initialize the backend
+    let backend = LlamaBackend::init()?;
+
+    // Set up model parameters
+    let mut model_params = LlamaModelParams::default();
+    if args.n_gpu_layers > 0 {
+        model_params = model_params.with_n_gpu_layers(args.n_gpu_layers);
+    }
+
+    // Load the model from splits
+    let model = LlamaModel::load_from_splits(&backend, &model_paths, &model_params)?;
+    println!("Model loaded successfully!");
+
+    // Get model info
+    let n_vocab = model.n_vocab();
+    println!("Model vocabulary size: {}", n_vocab);
+
+    // Create context
+    let ctx_params = LlamaContextParams::default()
+        .with_n_ctx(Some(NonZeroU32::new(args.ctx_size).unwrap()));
+
+    let mut ctx = model.new_context(&backend, ctx_params)?;
+    println!("Context created with size: {}", args.ctx_size);
+
+    // Tokenize the prompt
+    let tokens = model.str_to_token(&args.prompt, AddBos::Always)?;
+    println!("Prompt tokenized into {} tokens", tokens.len());
+
+    // Create batch
+    let mut batch = LlamaBatch::new(512, 1);
+
+    // Add tokens to batch
+    let last_index = tokens.len() - 1;
+    for (i, token) in tokens.iter().enumerate() {
+        let is_last = i == last_index;
+        batch.add(*token, i as i32, &[0], is_last)?;
+    }
+
+    // Decode the batch
+    ctx.decode(&mut batch)?;
+    println!("Initial prompt processed");
+
+    // Set up sampling
+    let mut sampler = LlamaSampler::chain_simple([
+        LlamaSampler::temp(args.temperature),
+        LlamaSampler::top_p(args.top_p, 1),
+    ]);
+
+    // Generate text
+    print!("{}", args.prompt);
+    io::stdout().flush()?;
+
+    let mut n_cur = batch.n_tokens();
+    let mut n_decode = 0;
+
+    while n_decode < args.n_predict {
+        // Sample the next token
+        let new_token = sampler.sample(&ctx, batch.n_tokens() - 1);
+        sampler.accept(new_token);
+
+        // Check for EOS
+        if model.is_eog_token(new_token) {
+            println!();
+            break;
+        }
+
+        // Print the token
+        let piece = model.token_to_str(new_token, llama_cpp_2::model::Special::Tokenize)?;
+        print!("{}", piece);
+        io::stdout().flush()?;
+
+        // Prepare the next batch
+        batch.clear();
+        batch.add(new_token, n_cur, &[0], true)?;
+        n_cur += 1;
+
+        // Decode
+        ctx.decode(&mut batch)?;
+        n_decode += 1;
+    }
+
+    println!("\n\nGeneration complete!");
+    println!("Generated {} tokens", n_decode);
+
+    // Demonstrate the split_prefix utility
+    if let Some(first_path) = model_paths.first() {
+        if let Some(path_str) = first_path.to_str() {
+            // Try to extract the prefix from the first split file
+            if let Some(prefix) = LlamaModel::split_prefix(path_str, 1, model_paths.len() as i32) {
+                println!("\nExtracted prefix from first split: {}", prefix);
+            }
+        }
+    }
+
+    Ok(())
+}
```
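Assuming the example builds as the `split_model` workspace member added above, it can presumably be run with something like `cargo run -p split_model -- --model model-00001-of-00003.gguf model-00002-of-00003.gguf model-00003-of-00003.gguf`, or with `--prefix` and `--num-splits` so the paths are generated via `split_path()`; the GGUF filenames here are placeholders.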

llama-cpp-2/src/model.rs

Lines changed: 168 additions & 0 deletions
````diff
@@ -622,6 +622,174 @@ impl LlamaModel {
         Ok(LlamaModel { model })
     }
 
+    /// Load a model from multiple split files.
+    ///
+    /// This function loads a model that has been split across multiple files. This is useful for
+    /// very large models that exceed filesystem limitations or need to be distributed across
+    /// multiple storage devices.
+    ///
+    /// # Arguments
+    ///
+    /// * `paths` - A slice of paths to the split model files
+    /// * `params` - The model parameters
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if:
+    /// - Any of the paths cannot be converted to a C string
+    /// - The model fails to load from the splits
+    /// - Any path doesn't exist or isn't accessible
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// use llama_cpp_2::model::{LlamaModel, params::LlamaModelParams};
+    /// use llama_cpp_2::llama_backend::LlamaBackend;
+    /// use std::path::Path;
+    ///
+    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// let backend = LlamaBackend::init()?;
+    /// let params = LlamaModelParams::default();
+    ///
+    /// let paths = vec![
+    ///     Path::new("model-00001-of-00003.gguf"),
+    ///     Path::new("model-00002-of-00003.gguf"),
+    ///     Path::new("model-00003-of-00003.gguf"),
+    /// ];
+    ///
+    /// let model = LlamaModel::load_from_splits(&backend, &paths, &params)?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    #[tracing::instrument(skip_all)]
+    pub fn load_from_splits(
+        _: &LlamaBackend,
+        paths: &[impl AsRef<Path>],
+        params: &LlamaModelParams,
+    ) -> Result<Self, LlamaModelLoadError> {
+        // Convert paths to C strings
+        let c_strings: Vec<CString> = paths
+            .iter()
+            .map(|p| {
+                let path = p.as_ref();
+                debug_assert!(path.exists(), "{path:?} does not exist");
+                let path_str = path
+                    .to_str()
+                    .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
+                CString::new(path_str).map_err(LlamaModelLoadError::from)
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        // Create array of pointers to C strings
+        let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
+
+        // Load the model from splits
+        let llama_model = unsafe {
+            llama_cpp_sys_2::llama_model_load_from_splits(
+                c_ptrs.as_ptr() as *mut *const c_char,
+                c_ptrs.len(),
+                params.params,
+            )
+        };
+
+        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
+
+        tracing::debug!("Loaded model from {} splits", paths.len());
+        Ok(LlamaModel { model })
+    }
+
+    /// Build a split GGUF file path for a specific chunk.
+    ///
+    /// This utility function creates the standardized filename for a split model chunk
+    /// following the pattern: `{prefix}-{split_no:05d}-of-{split_count:05d}.gguf`
+    ///
+    /// # Arguments
+    ///
+    /// * `path_prefix` - The base path and filename prefix
+    /// * `split_no` - The split number (1-indexed)
+    /// * `split_count` - The total number of splits
+    ///
+    /// # Returns
+    ///
+    /// Returns the formatted split path as a String
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use llama_cpp_2::model::LlamaModel;
+    ///
+    /// let path = LlamaModel::split_path("/models/llama", 2, 4);
+    /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
+    /// ```
+    pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
+        let mut buffer = vec![0u8; 1024];
+        let path_prefix_cstr = CString::new(path_prefix).unwrap_or_else(|_| CString::new("").unwrap());
+        let len = unsafe {
+            llama_cpp_sys_2::llama_split_path(
+                buffer.as_mut_ptr() as *mut c_char,
+                buffer.len(),
+                path_prefix_cstr.as_ptr(),
+                split_no,
+                split_count,
+            )
+        };
+
+        if len > 0 && len < buffer.len() as i32 {
+            buffer.truncate(len as usize);
+            String::from_utf8(buffer).unwrap_or_else(|_| String::new())
+        } else {
+            String::new()
+        }
+    }
+
+    /// Extract the path prefix from a split filename.
+    ///
+    /// This function extracts the base path prefix from a split model filename,
+    /// but only if the split_no and split_count match the pattern in the filename.
+    ///
+    /// # Arguments
+    ///
+    /// * `split_path` - The full path to a split file
+    /// * `split_no` - The expected split number
+    /// * `split_count` - The expected total number of splits
+    ///
+    /// # Returns
+    ///
+    /// Returns `Some(prefix)` if the split pattern matches, `None` otherwise
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use llama_cpp_2::model::LlamaModel;
+    ///
+    /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 2, 4);
+    /// assert_eq!(prefix, Some("/models/llama".to_string()));
+    ///
+    /// // Returns None if the pattern doesn't match
+    /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 3, 4);
+    /// assert_eq!(prefix, None);
+    /// ```
+    pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
+        let mut buffer = vec![0u8; 1024];
+        let split_path_cstr = CString::new(split_path).ok()?;
+        let len = unsafe {
+            llama_cpp_sys_2::llama_split_prefix(
+                buffer.as_mut_ptr() as *mut c_char,
+                buffer.len(),
+                split_path_cstr.as_ptr(),
+                split_no,
+                split_count,
+            )
+        };
+
+        if len > 0 && len < buffer.len() as i32 {
+            buffer.truncate(len as usize);
+            String::from_utf8(buffer).ok()
+        } else {
+            None
+        }
+    }
+
     /// Initializes a lora adapter from a file.
     ///
     /// # Errors
````
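As the diff shows, `split_path()` and `split_prefix()` are thin wrappers over `llama_cpp_sys_2::llama_split_path` and `llama_split_prefix`: they write into a fixed 1024-byte buffer and fall back to an empty `String` or `None` when the underlying C call reports failure or the result would not fit in that buffer.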
