@@ -421,28 +421,28 @@ def _ensure_sync_client_available(self) -> None:
 
     def _tokenize(
         self, texts: list[str], chunk_size: int
-    ) -> tuple[Iterable[int], list[list[int] | str], list[int]]:
-        """Take the input `texts` and `chunk_size` and return 3 iterables as a tuple.
+    ) -> tuple[Iterable[int], list[list[int] | str], list[int], list[int]]:
+        """Tokenize and batch input texts.
 
-        We have `batches`, where batches are sets of individual texts
-        we want responses from the openai api. The length of a single batch is
-        `chunk_size` texts.
+        Splits texts based on `embedding_ctx_length` and groups them into batches
+        of size `chunk_size`.
 
-        Each individual text is also split into multiple texts based on the
-        `embedding_ctx_length` parameter (based on number of tokens).
-
-        This function returns a 3-tuple of the following:
+        Args:
+            texts: The list of texts to tokenize.
+            chunk_size: The maximum number of texts to include in a single batch.
 
-        _iter: An iterable of the starting index in `tokens` for each *batch*
-        tokens: A list of tokenized texts, where each text has already been split
-            into sub-texts based on the `embedding_ctx_length` parameter. In the
-            case of tiktoken, this is a list of token arrays. In the case of
-            HuggingFace transformers, this is a list of strings.
-        indices: An iterable of the same length as `tokens` that maps each token-array
-            to the index of the original text in `texts`.
+        Returns:
+            A tuple containing:
+                1. An iterable of starting indices in the token list for each batch.
+                2. A list of tokenized texts (token arrays for tiktoken, strings for
+                   HuggingFace).
+                3. An iterable mapping each token array to the index of the original
+                   text. Same length as the token list.
+                4. A list of token counts for each tokenized text.
         """
         tokens: list[list[int] | str] = []
         indices: list[int] = []
+        token_counts: list[int] = []
         model_name = self.tiktoken_model_name or self.model
 
         # If tiktoken flag set to False
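The rewritten docstring compresses the split-then-batch behavior into two sentences. For readers skimming the diff, here is a minimal self-contained sketch of that behavior with toy values; `toy_tokenize` and its stand-in tokenizer are illustrative only, not the module's actual code:

```python
from collections.abc import Iterable

def toy_tokenize(
    texts: list[str], ctx_length: int, chunk_size: int
) -> tuple[Iterable[int], list[list[int]], list[int], list[int]]:
    """Split each text's tokens into ctx_length-sized chunks, then batch."""
    tokens: list[list[int]] = []
    indices: list[int] = []
    token_counts: list[int] = []
    for i, text in enumerate(texts):
        toks = [ord(c) for c in text]  # stand-in for a real tokenizer
        for j in range(0, len(toks), ctx_length):
            chunk = toks[j : j + ctx_length]
            tokens.append(chunk)
            indices.append(i)  # this chunk came from texts[i]
            token_counts.append(len(chunk))
    # each batch starts at one of these offsets into `tokens`
    _iter = range(0, len(tokens), chunk_size)
    return _iter, tokens, indices, token_counts

# "abcde" splits into two chunks with ctx_length=3 -> counts [3, 2, 2]
print(toy_tokenize(["abcde", "xy"], ctx_length=3, chunk_size=2))
```

The `indices` list is what lets callers recombine the per-chunk embeddings of a long text back into a single vector downstream.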
@@ -474,6 +474,7 @@ def _tokenize(
                     chunk_text: str = tokenizer.decode(token_chunk)
                     tokens.append(chunk_text)
                     indices.append(i)
+                    token_counts.append(len(token_chunk))
         else:
             try:
                 encoding = tiktoken.encoding_for_model(model_name)
@@ -503,6 +504,7 @@ def _tokenize(
                 for j in range(0, len(token), self.embedding_ctx_length):
                     tokens.append(token[j : j + self.embedding_ctx_length])
                     indices.append(i)
+                    token_counts.append(len(token[j : j + self.embedding_ctx_length]))
 
         if self.show_progress_bar:
             try:
@@ -513,7 +515,7 @@ def _tokenize(
                 _iter = range(0, len(tokens), chunk_size)
         else:
             _iter = range(0, len(tokens), chunk_size)
-        return _iter, tokens, indices
+        return _iter, tokens, indices, token_counts
 
     # please refer to
     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
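With the fourth element in place, every call site now unpacks four values. A hypothetical consumer loop over the tuple, reusing `toy_tokenize` from the sketch above:

```python
chunk_size = 2
_iter, tokens, indices, token_counts = toy_tokenize(
    ["abcde", "xy"], ctx_length=3, chunk_size=chunk_size
)
for start in _iter:
    batch = tokens[start : start + chunk_size]  # one API request's inputs
    n_tokens = sum(token_counts[start : start + chunk_size])
    print(f"batch at {start}: {len(batch)} chunks, {n_tokens} tokens")
```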
@@ -527,12 +529,12 @@ def _get_len_safe_embeddings(
     ) -> list[list[float]]:
         """Generate length-safe embeddings for a list of texts.
 
-        This method handles tokenization and embedding generation, respecting the set
-        embedding context length and chunk size. It supports both tiktoken and
-        HuggingFace tokenizer based on the tiktoken_enabled flag.
+        This method handles tokenization and embedding generation, respecting the
+        `embedding_ctx_length` and `chunk_size`. Supports both `tiktoken` and
+        HuggingFace `transformers` based on the `tiktoken_enabled` flag.
 
         Args:
-            texts: A list of texts to embed.
+            texts: The list of texts to embed.
             engine: The engine or model to use for embeddings.
             chunk_size: The size of chunks for processing embeddings.
 
@@ -541,12 +543,8 @@ def _get_len_safe_embeddings(
541543 """
542544 _chunk_size = chunk_size or self .chunk_size
543545 client_kwargs = {** self ._invocation_params , ** kwargs }
544- _iter , tokens , indices = self ._tokenize (texts , _chunk_size )
546+ _iter , tokens , indices , token_counts = self ._tokenize (texts , _chunk_size )
545547 batched_embeddings : list [list [float ]] = []
546- # Calculate token counts per chunk
547- token_counts = [
548- len (t ) if isinstance (t , list ) else len (t .split ()) for t in tokens
549- ]
550548
551549 # Process in batches respecting the token limit
552550 i = 0
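The deleted comprehension is the motivation for threading `token_counts` through `_tokenize`: `len(t)` is exact for tiktoken token arrays, but `len(t.split())` counts whitespace-separated words for HuggingFace string chunks, not tokens. A quick demonstration of the gap using `tiktoken` (exact counts vary by encoding):

```python
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
chunk_text = "unbelievably counterproductive"
print(len(chunk_text.split()))      # 2 -- the removed heuristic's estimate
print(len(enc.encode(chunk_text)))  # several subword tokens, not 2
```

Recording `len(token_chunk)` at split time gives the batching loop exact counts in both tokenizer branches.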
@@ -603,12 +601,12 @@ async def _aget_len_safe_embeddings(
     ) -> list[list[float]]:
         """Asynchronously generate length-safe embeddings for a list of texts.
 
-        This method handles tokenization and asynchronous embedding generation,
-        respecting the set embedding context length and chunk size. It supports both
-        `tiktoken` and HuggingFace `tokenizer` based on the tiktoken_enabled flag.
+        This method handles tokenization and embedding generation, respecting the
+        `embedding_ctx_length` and `chunk_size`. Supports both `tiktoken` and
+        HuggingFace `transformers` based on the `tiktoken_enabled` flag.
 
         Args:
-            texts: A list of texts to embed.
+            texts: The list of texts to embed.
             engine: The engine or model to use for embeddings.
             chunk_size: The size of chunks for processing embeddings.
 
@@ -617,14 +615,10 @@ async def _aget_len_safe_embeddings(
617615 """
618616 _chunk_size = chunk_size or self .chunk_size
619617 client_kwargs = {** self ._invocation_params , ** kwargs }
620- _iter , tokens , indices = await run_in_executor (
618+ _iter , tokens , indices , token_counts = await run_in_executor (
621619 None , self ._tokenize , texts , _chunk_size
622620 )
623621 batched_embeddings : list [list [float ]] = []
624- # Calculate token counts per chunk
625- token_counts = [
626- len (t ) if isinstance (t , list ) else len (t .split ()) for t in tokens
627- ]
628622
629623 # Process in batches respecting the token limit
630624 i = 0
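The async path mirrors the sync one but offloads the CPU-bound `_tokenize` to a thread via `langchain_core`'s `run_in_executor`, so the event loop is not blocked while tokenizing. A minimal sketch of that pattern, again using the illustrative `toy_tokenize` rather than the real method:

```python
import asyncio

from langchain_core.runnables.config import run_in_executor

async def tokenize_off_loop() -> None:
    # None selects the default executor; the callable itself stays synchronous
    _iter, tokens, indices, token_counts = await run_in_executor(
        None, toy_tokenize, ["abcde", "xy"], 3, 2
    )
    print(list(_iter), token_counts)

asyncio.run(tokenize_off_loop())
```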
@@ -676,12 +670,13 @@ async def empty_embedding() -> list[float]:
     def embed_documents(
         self, texts: list[str], chunk_size: int | None = None, **kwargs: Any
     ) -> list[list[float]]:
-        """Call out to OpenAI's embedding endpoint for embedding search docs.
+        """Call OpenAI's embedding endpoint to embed search docs.
 
         Args:
             texts: The list of texts to embed.
-            chunk_size: The chunk size of embeddings. If `None`, will use the chunk size
-                specified by the class.
+            chunk_size: The chunk size of embeddings.
+
+                If `None`, will use the chunk size specified by the class.
             kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
@@ -701,8 +696,8 @@ def embed_documents(
                 embeddings.extend(r["embedding"] for r in response["data"])
             return embeddings
 
-        # NOTE: to keep things simple, we assume the list may contain texts longer
-        # than the maximum context and use length-safe embedding function.
+        # Unconditionally call _get_len_safe_embeddings to handle length safety.
+        # This could be optimized to avoid double work when all texts are short enough.
         engine = cast(str, self.deployment)
         return self._get_len_safe_embeddings(
             texts, engine=engine, chunk_size=chunk_size, **kwargs
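Assuming this diff targets `langchain_openai`'s `OpenAIEmbeddings` (consistent with the identifiers above), typical sync usage is unchanged by the refactor; a sketch that assumes a configured `OPENAI_API_KEY`:

```python
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
# chunk_size=None falls back to the chunk_size configured on the instance
vectors = embeddings.embed_documents(["first doc", "second doc"])
print(len(vectors), len(vectors[0]))  # one embedding vector per input text
```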
@@ -711,12 +706,13 @@ def embed_documents(
     async def aembed_documents(
         self, texts: list[str], chunk_size: int | None = None, **kwargs: Any
     ) -> list[list[float]]:
-        """Call out to OpenAI's embedding endpoint async for embedding search docs.
+        """Asynchronously call OpenAI's embedding endpoint to embed search docs.
 
         Args:
             texts: The list of texts to embed.
-            chunk_size: The chunk size of embeddings. If `None`, will use the chunk size
-                specified by the class.
+            chunk_size: The chunk size of embeddings.
+
+                If `None`, will use the chunk size specified by the class.
             kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
@@ -735,8 +731,8 @@ async def aembed_documents(
                 embeddings.extend(r["embedding"] for r in response["data"])
             return embeddings
 
-        # NOTE: to keep things simple, we assume the list may contain texts longer
-        # than the maximum context and use length-safe embedding function.
+        # Unconditionally call _get_len_safe_embeddings to handle length safety.
+        # This could be optimized to avoid double work when all texts are short enough.
         engine = cast(str, self.deployment)
         return await self._aget_len_safe_embeddings(
             texts, engine=engine, chunk_size=chunk_size, **kwargs
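And the async counterpart, under the same assumptions as the sync sketch above:

```python
import asyncio

from langchain_openai import OpenAIEmbeddings

async def main() -> None:
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vectors = await embeddings.aembed_documents(["first doc", "second doc"])
    print(len(vectors), len(vectors[0]))

asyncio.run(main())
```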