@@ -762,7 +762,7 @@ def embed(
762762 """
763763 assert self ._ctx .ctx is not None
764764 n_embd = self .n_embd ()
765- n_ctx = self .n_ctx ()
765+ n_batch = self .n_batch
766766
767767 if self .context_params .embedding == False :
768768 raise RuntimeError (
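This hunk moves the truncation bound from the context window to the batch size, since every sequence in a call must now fit into a single `llama_batch`. A minimal sketch of a caller that satisfies the `embedding` check above (the model path and sizes are illustrative, not part of this commit):

```python
from llama_cpp import Llama

# embedding=True is required, otherwise embed() raises the RuntimeError above.
# After this change it is n_batch, not n_ctx, that bounds the longest input.
model = Llama(
    model_path="./models/embedding-model.gguf",  # hypothetical path
    embedding=True,
    n_batch=512,  # upper bound on tokens per input and per decoded batch
)
```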
@@ -782,54 +782,55 @@ def embed(
 
         # decode and fetch embeddings
         data: List[List[float]] = []
-        def decode_batch(sizes: List[int]):
+        def decode_batch(n_seq: int):
             assert self._ctx.ctx is not None
             llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()
 
             # store embeddings
-            for i, s in enumerate(sizes):
-                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+            for i in range(n_seq):
+                embedding: List[float] = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
                     :n_embd
                 ]
-                norm = np.linalg.norm(embedding) if normalize else s
-                embedding: List[float] = [v / float(norm) for v in embedding]
+                if normalize:
+                    norm = float(np.linalg.norm(embedding))
+                    embedding = [v / norm for v in embedding]
                 data.append(embedding)
 
         # init state
         total_tokens = 0
         t_batch = 0
-        s_sizes: List[int] = []
+        p_batch = 0
 
         # accumulate batches and encode
         for text in inputs:
             tokens = self.tokenize(text.encode("utf-8"))
             if truncate:
-                tokens = tokens[:n_ctx]
+                tokens = tokens[:n_batch]
 
             n_tokens = len(tokens)
             total_tokens += n_tokens
 
             # check for overrun
-            if n_tokens > n_ctx:
+            if n_tokens > n_batch:
                 raise ValueError(
-                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                    f"Requested tokens ({n_tokens}) exceed batch size of {n_batch}"
                 )
 
             # time to eval batch
-            if t_batch + n_tokens > self._n_ctx:
-                decode_batch(s_sizes)
+            if t_batch + n_tokens > n_batch:
+                decode_batch(p_batch)
                 t_batch = 0
-                s_sizes = []
+                p_batch = 0
 
             # add to batch
-            self._batch.add_sequence(tokens, len(s_sizes), False)
+            self._batch.add_sequence(tokens, p_batch, False)
             t_batch += n_tokens
-            s_sizes.append(n_tokens)
+            p_batch += 1
 
         # handle last batch
-        decode_batch(s_sizes)
+        decode_batch(p_batch)
 
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)
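Taken together, this hunk replaces the per-sequence size list `s_sizes` with a plain counter `p_batch`, since `llama_get_embeddings_ith` only needs the sequence index. It also changes behavior when `normalize=False`: the old code divided each component by the token count `s`, while the new code returns the embedding exactly as the model produced it. A hedged usage sketch of the revised method (assuming, as the `data` list above suggests, that `embed` returns one vector per input):

```python
import numpy as np

# Three short inputs fit into one n_batch-sized llama_batch and are
# decoded together; p_batch counts the sequences in the pending batch.
texts = ["first document", "second document", "third document"]
vectors = model.embed(texts, normalize=True, truncate=True)

# With normalize=True each vector is divided by its own L2 norm,
# so cosine similarity reduces to a plain dot product.
similarity = float(np.dot(vectors[0], vectors[1]))
print(f"cos(doc0, doc1) = {similarity:.4f}")
```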