//! Example demonstrating how to load split GGUF models.
//!
//! This example shows how to:
//! - Load a model split across multiple files
//! - Use utility functions to work with split file naming conventions
//! - Generate text from a split model
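//!
//! The invocations below are illustrative sketches only: the example name
//! (`load_split_model`) and the GGUF file names are placeholders, so adjust
//! them to match how this example and your model files are actually named.
//!
//! ```bash
//! # List every split file explicitly
//! cargo run --example load_split_model -- \
//!     -m model-00001-of-00002.gguf -m model-00002-of-00002.gguf
//!
//! # Or derive the split paths from a prefix and a split count
//! cargo run --example load_split_model -- -p model -n 2 --prompt "Once upon a time"
//! ```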

use anyhow::Result;
use clap::Parser;
use llama_cpp_2::{
    context::params::LlamaContextParams,
    llama_backend::LlamaBackend,
    llama_batch::LlamaBatch,
    model::{params::LlamaModelParams, AddBos, LlamaModel},
    sampling::LlamaSampler,
};
use std::io::{self, Write};
use std::num::NonZeroU32;
use std::path::PathBuf;

/// Command line arguments for the split model example
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Paths to the split model files (can be specified multiple times)
    #[arg(short = 'm', long = "model", required = true, num_args = 1..)]
    model_paths: Vec<PathBuf>,

    /// Alternatively, provide a split file prefix; the split paths are generated from it
    #[arg(short = 'p', long = "prefix", conflicts_with = "model_paths")]
    prefix: Option<String>,

    /// Number of splits (required if using --prefix)
    #[arg(short = 'n', long = "num-splits", requires = "prefix")]
    num_splits: Option<u32>,

    /// Prompt to use for generation
    #[arg(short = 't', long = "prompt", default_value = "Once upon a time")]
    prompt: String,

    /// Number of tokens to generate
    #[arg(short = 'g', long = "n-predict", default_value_t = 128)]
    n_predict: i32,

    /// Number of layers to offload to the GPU
    #[arg(short = 'l', long = "n-gpu-layers", default_value_t = 0)]
    n_gpu_layers: u32,

    /// Context size
    #[arg(short = 'c', long = "ctx-size", default_value_t = 2048)]
    ctx_size: u32,

    /// Temperature for sampling
    #[arg(long = "temp", default_value_t = 0.8)]
    temperature: f32,

    /// Top-p for sampling
    #[arg(long = "top-p", default_value_t = 0.95)]
    top_p: f32,

    /// Seed for random number generation
    #[arg(long = "seed", default_value_t = 1234)]
    seed: u32,
}

fn main() -> Result<()> {
    let args = Args::parse();

    // Determine the model paths
    let model_paths = if let Some(prefix) = args.prefix {
        let num_splits = args.num_splits.expect("num-splits required with prefix");

        // Generate split paths using the utility function
        let mut paths = Vec::new();
        for i in 1..=num_splits {
            let path = LlamaModel::split_path(&prefix, i as i32, num_splits as i32);
            paths.push(PathBuf::from(path));
        }
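        // Assuming `split_path` follows llama.cpp's usual
        // "<prefix>-%05d-of-%05d.gguf" naming scheme, a prefix of "model" with
        // two splits would yield "model-00001-of-00002.gguf" and
        // "model-00002-of-00002.gguf".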

        println!("Generated split paths:");
        for path in &paths {
            println!(" - {}", path.display());
        }

        paths
    } else {
        args.model_paths
    };

    // Verify all split files exist
    for path in &model_paths {
        if !path.exists() {
            eprintln!("Error: Split file not found: {}", path.display());
            std::process::exit(1);
        }
    }

    println!("Loading model from {} splits...", model_paths.len());

    // Initialize the backend
    let backend = LlamaBackend::init()?;

    // Set up model parameters
    let mut model_params = LlamaModelParams::default();
    if args.n_gpu_layers > 0 {
        model_params = model_params.with_n_gpu_layers(args.n_gpu_layers);
    }

    // Load the model from splits
    let model = LlamaModel::load_from_splits(&backend, &model_paths, &model_params)?;
    println!("Model loaded successfully!");

    // Get model info
    let n_vocab = model.n_vocab();
    println!("Model vocabulary size: {}", n_vocab);

    // Create context
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(Some(NonZeroU32::new(args.ctx_size).unwrap()));

    let mut ctx = model.new_context(&backend, ctx_params)?;
    println!("Context created with size: {}", args.ctx_size);

    // Tokenize the prompt
    let tokens = model.str_to_token(&args.prompt, AddBos::Always)?;
    println!("Prompt tokenized into {} tokens", tokens.len());

    // Create batch
    let mut batch = LlamaBatch::new(512, 1);

    // Add tokens to batch
    let last_index = tokens.len() - 1;
    for (i, token) in tokens.iter().enumerate() {
        let is_last = i == last_index;
        batch.add(*token, i as i32, &[0], is_last)?;
    }

    // Decode the batch
    ctx.decode(&mut batch)?;
    println!("Initial prompt processed");

    // Set up sampling; the final `dist` sampler selects a token from the
    // filtered distribution using the configured seed
    let mut sampler = LlamaSampler::chain_simple([
        LlamaSampler::temp(args.temperature),
        LlamaSampler::top_p(args.top_p, 1),
        LlamaSampler::dist(args.seed),
    ]);

    // Generate text
    print!("{}", args.prompt);
    io::stdout().flush()?;

    let mut n_cur = batch.n_tokens();
    let mut n_decode = 0;

    while n_decode < args.n_predict {
        // Sample the next token
        let new_token = sampler.sample(&ctx, batch.n_tokens() - 1);
        sampler.accept(new_token);

        // Check for end-of-generation
        if model.is_eog_token(new_token) {
            println!();
            break;
        }

        // Print the token
        let piece = model.token_to_str(new_token, llama_cpp_2::model::Special::Tokenize)?;
        print!("{}", piece);
        io::stdout().flush()?;

        // Prepare the next batch with the sampled token
        batch.clear();
        batch.add(new_token, n_cur, &[0], true)?;
        n_cur += 1;

        // Decode
        ctx.decode(&mut batch)?;
        n_decode += 1;
    }

    println!("\n\nGeneration complete!");
    println!("Generated {} tokens", n_decode);

    // Demonstrate the split_prefix utility
    if let Some(first_path) = model_paths.first() {
        if let Some(path_str) = first_path.to_str() {
            // Try to extract the prefix from the first split file
            if let Some(prefix) = LlamaModel::split_prefix(path_str, 1, model_paths.len() as i32) {
                println!("\nExtracted prefix from first split: {}", prefix);
            }
        }
    }

    Ok(())
}