@@ -565,70 +565,6 @@ std::pair<std::string, std::string> common_get_hf_file(
565565// clear LoRA adapters from context, then apply new list of adapters
566566void common_set_adapter_lora (struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
567567
568- //
569- // Batch utils
570- //
571-
572- // convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
573- // this is meant to be temporary
574- struct common_batch {
575- llama_batch_ext_ptr batch;
576- struct batch_token {
577- llama_token token;
578- llama_seq_id seq_id; // only support single seq for now
579- bool logits;
580- };
581- std::vector<batch_token> tokens;
582- int n_outputs = 0 ;
583- common_batch () = default ;
584- common_batch (int32_t n_tokens, int32_t n_seq_max) {
585- batch.reset (llama_batch_ext_init (n_tokens, n_seq_max));
586- tokens.reserve (n_tokens);
587- }
588- void clear () {
589- llama_batch_ext_clear (batch.get ());
590- tokens.clear ();
591- }
592- void add_text (llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
593- llama_batch_ext_add_text (batch.get (), token, pos, &seq_id, 1 , logits);
594- tokens.push_back ({token, seq_id, logits});
595- if (logits) {
596- n_outputs++;
597- }
598- }
599- void add_text_multi_seq (llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
600- llama_batch_ext_add_text (batch.get (), token, pos, seq_ids.data (), seq_ids.size (), logits);
601- tokens.push_back ({token, seq_ids[0 ], logits});
602- if (logits) {
603- n_outputs++;
604- }
605- }
606- void set_logits_last () {
607- if (!tokens.empty ()) {
608- llama_batch_ext_set_output_last (batch.get ());
609- tokens.back ().logits = true ;
610- }
611- }
612- int32_t get_n_tokens () const {
613- return (int32_t )tokens.size ();
614- }
615- llama_batch_ext * get () {
616- return batch.get ();
617- }
618- common_batch get_view (int32_t offset, int32_t n_tokens) {
619- common_batch view;
620- view.batch = llama_batch_ext_ptr (llama_batch_ext_get_view (batch.get (), offset, n_tokens));
621- view.tokens .reserve (n_tokens);
622- for (int32_t i = 0 ; i < n_tokens; i++) {
623- view.tokens .push_back (tokens[offset + i]);
624- if (tokens[offset + i].logits ) {
625- view.n_outputs ++;
626- }
627- }
628- return view;
629- }
630- };
631-
632568//
633569// Token utils
634570//
0 commit comments