1+ #[ cfg( feature = "llama" ) ]
2+ mod llama;
3+
4+ #[ cfg( feature = "llama" ) ]
5+ pub use llama:: { test_infer_paralle, Task , WorkerSeed } ;
6+
17use gguf:: {
28 ext:: { utok, Mmap } ,
39 map_files, GGufMetaMapExt , GGufModel , Message , Tokenizer ,
410} ;
511use std:: {
612 env:: { var, var_os} ,
713 fmt,
8- iter:: zip,
914 path:: { Path , PathBuf } ,
1015 str:: FromStr ,
11- sync:: {
12- mpsc:: { self , Sender } ,
13- Once ,
14- } ,
16+ sync:: Once ,
1517 time:: { Duration , Instant } ,
1618} ;
1719
@@ -26,13 +28,27 @@ pub struct Inference {
2628 pub max_steps : usize ,
2729}
2830
31+ mod env {
32+ pub const TEST_MODEL : & str = "TEST_MODEL" ;
33+ pub const TEST_IMAGE : & str = "TEST_IMAGE" ;
34+ pub const DEVICES : & str = "DEVICES" ;
35+ pub const PROMPT : & str = "PROMPT" ;
36+ pub const AS_USER : & str = "AS_USER" ;
37+ pub const TEMPERATURE : & str = "TEMPERATURE" ;
38+ pub const TOP_P : & str = "TOP_P" ;
39+ pub const TOP_K : & str = "TOP_K" ;
40+ pub const MAX_STEPS : & str = "MAX_STEPS" ;
41+ pub const ROLL_CACHE_SIZE : & str = "ROLL_CACHE_SIZE" ;
42+ }
43+ use env:: * ;
44+
2945impl Inference {
3046 pub fn load ( ) -> Option < Self > {
3147 static ONCE : Once = Once :: new ( ) ;
3248 ONCE . call_once ( env_logger:: init) ;
3349
34- let Some ( path) = var_os ( " TEST_MODEL" ) else {
35- println ! ( "TEST_MODEL not set" ) ;
50+ let Some ( path) = var_os ( TEST_MODEL ) else {
51+ println ! ( "{ TEST_MODEL} not set" ) ;
3652 return None ;
3753 } ;
3854 let path = Path :: new ( & path) ;
@@ -50,26 +66,26 @@ impl Inference {
5066
5167 Some ( Self {
5268 model : map_files ( path) ,
53- devices : var ( " DEVICES" ) . ok ( ) ,
54- prompt : var ( " PROMPT" ) . unwrap_or_else ( |_| String :: from ( "Once upon a time," ) ) ,
55- as_user : var ( " AS_USER" ) . ok ( ) . is_some_and ( |s| !s. is_empty ( ) ) ,
56- temperature : parse ( " TEMPERATURE" , 0. ) ,
57- top_p : parse ( " TOP_P" , 1. ) ,
58- top_k : parse ( " TOP_K" , usize:: MAX ) ,
59- max_steps : parse ( " MAX_STEPS" , usize:: MAX ) ,
69+ devices : var ( DEVICES ) . ok ( ) ,
70+ prompt : var ( PROMPT ) . unwrap_or_else ( |_| String :: from ( "Once upon a time," ) ) ,
71+ as_user : var ( AS_USER ) . ok ( ) . is_some_and ( |s| !s. is_empty ( ) ) ,
72+ temperature : parse ( TEMPERATURE , 0. ) ,
73+ top_p : parse ( TOP_P , 1. ) ,
74+ top_k : parse ( TOP_K , usize:: MAX ) ,
75+ max_steps : parse ( MAX_STEPS , usize:: MAX ) ,
6076 } )
6177 }
6278}
6379
6480pub fn load_roll_cache_size ( ) -> usize {
65- var ( " ROLL_CACHE_SIZE" )
81+ var ( ROLL_CACHE_SIZE )
6682 . ok ( )
6783 . and_then ( |s| s. parse ( ) . ok ( ) )
6884 . unwrap_or ( usize:: MAX )
6985}
7086
7187pub fn image ( ) -> Option < PathBuf > {
72- var_os ( " TEST_IMAGE" ) . map ( PathBuf :: from)
88+ var_os ( TEST_IMAGE ) . map ( PathBuf :: from)
7389}
7490
7591pub struct TokenizerAndPrompt {
@@ -179,71 +195,3 @@ pub fn test_infer(
179195 ]
180196 }
181197}
182-
183- #[ cfg( feature = "llama" ) ]
184- pub fn test_infer_paralle < ' w > (
185- model : & llama:: LlamaStorage < & ' w [ u8 ] > ,
186- senders : Box < [ mpsc:: Sender < Task > ] > ,
187- eos : utok ,
188- tokenizer : Tokenizer ,
189- prompt : & str ,
190- max_steps : usize ,
191- ) {
192- use tensor:: Blob ;
193-
194- let ( next, next_recv) = mpsc:: channel ( ) ;
195- test_infer ( eos, tokenizer, prompt, max_steps, |input, pos| {
196- let mut embd = model. meta . embd ( input. len ( ) ) . map ( Blob :: new) . take ( ) ;
197-
198- let d = embd. len ( ) / input. len ( ) ;
199- for ( i, & tok) in input. iter ( ) . enumerate ( ) {
200- embd[ i * d..] [ ..d] . copy_from_slice ( & model. token_embd [ tok as usize * d..] [ ..d] ) ;
201- }
202-
203- for sender in & senders {
204- sender
205- . send ( Task {
206- nt : input. len ( ) ,
207- pos,
208- embd : embd. as_ptr ( ) ,
209- next : next. clone ( ) ,
210- } )
211- . unwrap ( )
212- }
213- next_recv. recv ( ) . unwrap ( )
214- } ) ;
215- }
216-
217- pub struct Task {
218- pub nt : usize ,
219- pub pos : usize ,
220- pub embd : * const u8 ,
221- pub next : mpsc:: Sender < utok > ,
222- }
223-
224- unsafe impl Send for Task { }
225-
226- pub struct WorkerSeed < N > {
227- pub tasks : mpsc:: Receiver < Task > ,
228- pub node : N ,
229- }
230-
231- impl < N > WorkerSeed < N > {
232- pub fn new ( nodes : Vec < N > ) -> ( Vec < Self > , Vec < Sender < Task > > ) {
233- let n = nodes. len ( ) ;
234-
235- let mut tasks = Vec :: with_capacity ( n) ;
236- let mut senders = Vec :: with_capacity ( n) ;
237- for _ in 0 ..n {
238- let ( sender, receiver) = std:: sync:: mpsc:: channel ( ) ;
239- tasks. push ( receiver) ;
240- senders. push ( sender) ;
241- }
242- (
243- zip ( nodes, tasks)
244- . map ( |( node, tasks) | Self { node, tasks } )
245- . collect ( ) ,
246- senders,
247- )
248- }
249- }
0 commit comments