@@ -5,35 +5,19 @@ enum LlamaError: Error {
     case couldNotInitializeContext
 }
 
-func llama_batch_clear(_ batch: inout llama_batch) {
-    batch.n_tokens = 0
-}
-
-func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
-    batch.token[Int(batch.n_tokens)] = id
-    batch.pos[Int(batch.n_tokens)] = pos
-    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
-    for i in 0..<seq_ids.count {
-        batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
-    }
-    batch.logits[Int(batch.n_tokens)] = logits ? 1 : 0
-
-    batch.n_tokens += 1
-}
-
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
     private var vocab: OpaquePointer
     private var sampling: UnsafeMutablePointer<llama_sampler>
-    private var batch: llama_batch
+    private var batch: OpaquePointer
     private var tokens_list: [llama_token]
     var is_done: Bool = false
 
     /// This variable is used to store temporarily invalid cchars
     private var temporary_invalid_cchars: [CChar]
 
-    var n_len: Int32 = 1024
+    var n_len: Int32 = 128
     var n_cur: Int32 = 0
 
     var n_decode: Int32 = 0
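
This first hunk removes the hand-rolled `llama_batch_clear`/`llama_batch_add` helpers, which worked by writing directly into the public fields of the old `llama_batch` struct. With `batch` now an `OpaquePointer` to a `llama_batch_ext`, those fields are no longer reachable from Swift, so the library's own mutators take over. The correspondence, as far as this diff shows it:

```swift
// Rough mapping from the deleted struct-based helpers to the
// llama_batch_ext calls used in the rest of this diff (not an
// exhaustive API listing):
//
//   llama_batch_clear(&batch)                  -> llama_batch_ext_clear(batch)
//   llama_batch_add(&batch, id, pos, [0], out) -> llama_batch_ext_add_text(batch, id, pos,
//                                                     [llama_seq_id(0)], 1, out)
//   batch.logits[Int(batch.n_tokens) - 1] = 1  -> llama_batch_ext_set_output_last(batch)
//   batch.n_tokens                             -> llama_batch_ext_get_n_tokens(batch)
```
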
@@ -42,7 +26,7 @@ actor LlamaContext {
         self.model = model
         self.context = context
         self.tokens_list = []
-        self.batch = llama_batch_init(512, 0, 1)
+        self.batch = llama_batch_ext_init(512, 1)
         self.temporary_invalid_cchars = []
         let sparams = llama_sampler_chain_default_params()
         self.sampling = llama_sampler_chain_init(sparams)
@@ -53,7 +37,7 @@ actor LlamaContext {
 
     deinit {
         llama_sampler_free(sampling)
-        llama_batch_free(batch)
+        llama_batch_ext_free(batch)
         llama_model_free(model)
         llama_free(context)
         llama_backend_free()
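
Two things change around the batch's lifetime. The constructor loses the middle `embd` argument of the old `llama_batch_init(n_tokens, embd, n_seq_max)`, so `llama_batch_ext_init(512, 1)` presumably takes just the token capacity and the maximum sequence count; the matching destructor becomes `llama_batch_ext_free`. A minimal sketch of that ownership pairing, using a hypothetical wrapper class that is not part of this patch:

```swift
// Hypothetical RAII-style wrapper illustrating the init/free pairing above;
// assumes llama_batch_ext_init(n_tokens_max, n_seq_max) as used in this diff.
final class BatchHandle {
    let raw: OpaquePointer

    init(capacity: Int32, maxSequences: Int32) {
        raw = llama_batch_ext_init(capacity, maxSequences)
    }

    deinit {
        llama_batch_ext_free(raw)  // must be freed exactly once
    }
}
```
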
@@ -111,7 +95,7 @@ actor LlamaContext {
     }
 
     func get_n_tokens() -> Int32 {
-        return batch.n_tokens;
+        return llama_batch_ext_get_n_tokens(batch)
     }
 
     func completion_init(text: String) {
@@ -133,25 +117,25 @@ actor LlamaContext {
             print(String(cString: token_to_piece(token: id) + [0]))
         }
 
-        llama_batch_clear(&batch)
+        llama_batch_ext_clear(batch)
 
         for i1 in 0..<tokens_list.count {
             let i = Int(i1)
-            llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
+            llama_batch_ext_add_text(batch, tokens_list[i], Int32(i), [llama_seq_id(0)], 1, false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        llama_batch_ext_set_output_last(batch)
 
-        if llama_decode(context, batch) != 0 {
-            print("llama_decode() failed")
+        if llama_decode_ext(context, batch) != 0 {
+            print("llama_decode_ext() failed")
         }
 
-        n_cur = batch.n_tokens
+        n_cur = llama_batch_ext_get_n_tokens(batch)
     }
 
     func completion_loop() -> String {
         var new_token_id: llama_token = 0
 
-        new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
+        new_token_id = llama_sampler_sample(sampling, context, llama_batch_ext_get_n_tokens(batch) - 1)
 
         if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
             print("\n")
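
The prompt path in `completion_init` now follows a fixed sequence against the opaque handle: clear, append each prompt token, mark only the last token for logits, then decode. A condensed sketch of that flow, using only the calls that appear in this diff (the `ingestPrompt` helper itself is illustrative):

```swift
// Illustrative helper condensing the prompt-ingestion steps above.
func ingestPrompt(_ prompt: [llama_token], batch: OpaquePointer, context: OpaquePointer) {
    llama_batch_ext_clear(batch)  // drop whatever the batch held before
    for (i, token) in prompt.enumerated() {
        // (batch, token, position, seq_ids, n_seq_ids, want_logits)
        llama_batch_ext_add_text(batch, token, Int32(i), [llama_seq_id(0)], 1, false)
    }
    llama_batch_ext_set_output_last(batch)  // request logits only for the final token
    if llama_decode_ext(context, batch) != 0 {
        print("llama_decode_ext() failed")
    }
}
```
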
@@ -178,13 +162,13 @@ actor LlamaContext {
         print(new_token_str)
         // tokens_list.append(new_token_id)
 
-        llama_batch_clear(&batch)
-        llama_batch_add(&batch, new_token_id, n_cur, [0], true)
+        llama_batch_ext_clear(batch)
+        llama_batch_ext_add_text(batch, new_token_id, n_cur, [llama_seq_id(0)], 1, true)
 
         n_decode += 1
         n_cur += 1
 
-        if llama_decode(context, batch) != 0 {
+        if llama_decode_ext(context, batch) != 0 {
             print("failed to evaluate llama!")
         }
 
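
Generation then becomes a one-token-per-iteration loop: sample at the last batch index (the only position whose logits were requested), refill the batch with that single token, and decode again. A sketch of one step under the same assumptions:

```swift
// Illustrative single generation step mirroring completion_loop() above.
func generateStep(sampling: UnsafeMutablePointer<llama_sampler>,
                  context: OpaquePointer, batch: OpaquePointer, pos: Int32) -> llama_token {
    // Sample at the last index in the batch; its logits were requested.
    let token = llama_sampler_sample(sampling, context, llama_batch_ext_get_n_tokens(batch) - 1)

    llama_batch_ext_clear(batch)
    // want_logits = true, so the next llama_sampler_sample call has data.
    llama_batch_ext_add_text(batch, token, pos, [llama_seq_id(0)], 1, true)

    if llama_decode_ext(context, batch) != 0 {
        print("failed to evaluate llama!")
    }
    return token
}
```
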
@@ -201,21 +185,21 @@ actor LlamaContext {
         for _ in 0..<nr {
             // bench prompt processing
 
-            llama_batch_clear(&batch)
+            llama_batch_ext_clear(batch)
 
             let n_tokens = pp
 
             for i in 0..<n_tokens {
-                llama_batch_add(&batch, 0, Int32(i), [0], false)
+                llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(0)], 1, false)
             }
-            batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+            llama_batch_ext_set_output_last(batch)
 
             llama_kv_self_clear(context)
 
             let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
-            if llama_decode(context, batch) != 0 {
-                print("llama_decode() failed during prompt")
+            if llama_decode_ext(context, batch) != 0 {
+                print("llama_decode_ext() failed during prompt")
             }
             llama_synchronize(context)
 
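
In the prompt-processing benchmark, each iteration clears the KV cache first, so every run measures a cold prompt, and `llama_synchronize` runs before the clock is read because decoding may complete asynchronously on the backend. One iteration, condensed (dummy token id 0 and `pp` prompt tokens, as in the loop above):

```swift
// One prompt-processing bench iteration, condensed from the hunk above.
llama_batch_ext_clear(batch)
for i in 0..<pp {
    llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(0)], 1, false)
}
llama_batch_ext_set_output_last(batch)

llama_kv_self_clear(context)  // start every run from an empty KV cache

let t0 = DispatchTime.now().uptimeNanoseconds
if llama_decode_ext(context, batch) != 0 {
    print("llama_decode_ext() failed during prompt")
}
llama_synchronize(context)    // wait for the backend before stopping the timer
let t1 = DispatchTime.now().uptimeNanoseconds
```
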
@@ -228,14 +212,14 @@ actor LlamaContext {
             let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
 
             for i in 0..<tg {
-                llama_batch_clear(&batch)
+                llama_batch_ext_clear(batch)
 
                 for j in 0..<pl {
-                    llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
+                    llama_batch_ext_add_text(batch, 0, Int32(i), [llama_seq_id(Int32(j))], 1, true)
                 }
 
-                if llama_decode(context, batch) != 0 {
-                    print("llama_decode() failed during text generation")
+                if llama_decode_ext(context, batch) != 0 {
+                    print("llama_decode_ext() failed during text generation")
                 }
                 llama_synchronize(context)
             }
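
The text-generation benchmark adds the same dummy token once per parallel sequence `j`, each call carrying a single-element sequence-id array. Since `llama_batch_ext_add_text` takes an array of sequence ids plus a count, a token shared by all `pl` sequences could presumably be attached in one call instead; a hedged sketch, assuming the fifth argument counts entries in that array:

```swift
// Hedged alternative to the per-sequence loop above: one add_text call
// carrying all pl sequence ids (assumes the 5th argument is the id count).
let seqIds = (0..<pl).map { llama_seq_id(Int32($0)) }
llama_batch_ext_add_text(batch, 0, Int32(i), seqIds, Int32(pl), true)
```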