
Commit 099c042

refactor(openai): embedding utils and calculations (#33982)
Now returns (`_iter`, `tokens`, `indices`, `token_counts`). The `token_counts` are calculated directly during tokenization, which is more accurate and efficient than splitting strings later.
1 parent 2d4f00a commit 099c042
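
The gist of the change, as a minimal runnable sketch (hypothetical helper name, not the actual LangChain code): the old approach reconstructed token counts from the finished chunks with `len(t.split())` for string chunks, which counts whitespace-separated words rather than tokens; the new approach records each chunk's count at the moment the chunk is produced.

# Hypothetical sketch of the idea, not the library's implementation.
def tokenize_with_counts(
    texts: list[str], ctx_length: int
) -> tuple[list[list[int]], list[int], list[int]]:
    """Split each text into token chunks of at most ctx_length, recording counts."""
    tokens: list[list[int]] = []
    indices: list[int] = []
    token_counts: list[int] = []
    for i, text in enumerate(texts):
        ids = [ord(c) for c in text]  # toy tokenizer standing in for tiktoken
        for j in range(0, len(ids), ctx_length):
            chunk = ids[j : j + ctx_length]
            tokens.append(chunk)
            indices.append(i)
            token_counts.append(len(chunk))  # exact, known at tokenization time
    return tokens, indices, token_counts

# Old (removed) approach, reconstructed after the fact:
#     token_counts = [len(t) if isinstance(t, list) else len(t.split()) for t in tokens]
# For HuggingFace string chunks, len(t.split()) counts words, not tokens.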

File tree

3 files changed (+46, -48 lines)

libs/partners/openai/langchain_openai/embeddings/base.py

Lines changed: 41 additions & 45 deletions
@@ -421,28 +421,28 @@ def _ensure_sync_client_available(self) -> None:
 
     def _tokenize(
         self, texts: list[str], chunk_size: int
-    ) -> tuple[Iterable[int], list[list[int] | str], list[int]]:
-        """Take the input `texts` and `chunk_size` and return 3 iterables as a tuple.
+    ) -> tuple[Iterable[int], list[list[int] | str], list[int], list[int]]:
+        """Tokenize and batch input texts.
 
-        We have `batches`, where batches are sets of individual texts
-        we want responses from the openai api. The length of a single batch is
-        `chunk_size` texts.
+        Splits texts based on `embedding_ctx_length` and groups them into batches
+        of size `chunk_size`.
 
-        Each individual text is also split into multiple texts based on the
-        `embedding_ctx_length` parameter (based on number of tokens).
-
-        This function returns a 3-tuple of the following:
+        Args:
+            texts: The list of texts to tokenize.
+            chunk_size: The maximum number of texts to include in a single batch.
 
-            _iter: An iterable of the starting index in `tokens` for each *batch*
-            tokens: A list of tokenized texts, where each text has already been split
-                into sub-texts based on the `embedding_ctx_length` parameter. In the
-                case of tiktoken, this is a list of token arrays. In the case of
-                HuggingFace transformers, this is a list of strings.
-            indices: An iterable of the same length as `tokens` that maps each token-array
-                to the index of the original text in `texts`.
+        Returns:
+            A tuple containing:
+                1. An iterable of starting indices in the token list for each batch.
+                2. A list of tokenized texts (token arrays for tiktoken, strings for
+                   HuggingFace).
+                3. An iterable mapping each token array to the index of the original
+                   text. Same length as the token list.
+                4. A list of token counts for each tokenized text.
         """
         tokens: list[list[int] | str] = []
         indices: list[int] = []
+        token_counts: list[int] = []
         model_name = self.tiktoken_model_name or self.model
 
         # If tiktoken flag set to False
@@ -474,6 +474,7 @@ def _tokenize(
                     chunk_text: str = tokenizer.decode(token_chunk)
                     tokens.append(chunk_text)
                     indices.append(i)
+                    token_counts.append(len(token_chunk))
         else:
             try:
                 encoding = tiktoken.encoding_for_model(model_name)
@@ -503,6 +504,7 @@ def _tokenize(
             for j in range(0, len(token), self.embedding_ctx_length):
                 tokens.append(token[j : j + self.embedding_ctx_length])
                 indices.append(i)
+                token_counts.append(len(token[j : j + self.embedding_ctx_length]))
 
         if self.show_progress_bar:
             try:
@@ -513,7 +515,7 @@ def _tokenize(
                 _iter = range(0, len(tokens), chunk_size)
         else:
             _iter = range(0, len(tokens), chunk_size)
-        return _iter, tokens, indices
+        return _iter, tokens, indices, token_counts
 
     # please refer to
     # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
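
For callers, the only visible difference is the extra element in the returned tuple. A minimal usage sketch (placeholder model name and texts; assumes an OpenAI API key is configured; `_tokenize` is a private helper, shown here only to illustrate the new shape):

from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")  # placeholder model
# _tokenize is private; shown only to illustrate the new 4-tuple return.
_iter, tokens, indices, token_counts = embeddings._tokenize(["hello world"], 1000)
assert len(tokens) == len(indices) == len(token_counts)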
@@ -527,12 +529,12 @@ def _get_len_safe_embeddings(
     ) -> list[list[float]]:
         """Generate length-safe embeddings for a list of texts.
 
-        This method handles tokenization and embedding generation, respecting the set
-        embedding context length and chunk size. It supports both tiktoken and
-        HuggingFace tokenizer based on the tiktoken_enabled flag.
+        This method handles tokenization and embedding generation, respecting the
+        `embedding_ctx_length` and `chunk_size`. Supports both `tiktoken` and
+        HuggingFace `transformers` based on the `tiktoken_enabled` flag.
 
         Args:
-            texts: A list of texts to embed.
+            texts: The list of texts to embed.
             engine: The engine or model to use for embeddings.
             chunk_size: The size of chunks for processing embeddings.
 
@@ -541,12 +543,8 @@ def _get_len_safe_embeddings(
         """
         _chunk_size = chunk_size or self.chunk_size
         client_kwargs = {**self._invocation_params, **kwargs}
-        _iter, tokens, indices = self._tokenize(texts, _chunk_size)
+        _iter, tokens, indices, token_counts = self._tokenize(texts, _chunk_size)
         batched_embeddings: list[list[float]] = []
-        # Calculate token counts per chunk
-        token_counts = [
-            len(t) if isinstance(t, list) else len(t.split()) for t in tokens
-        ]
 
         # Process in batches respecting the token limit
         i = 0
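
The batching loop that follows this context (not part of the diff) consumes `token_counts` to keep each request under a token budget. A hedged sketch of that pattern, using a hypothetical `max_tokens_per_batch` limit rather than the library's actual constant:

def batch_by_token_limit(
    tokens: list[list[int] | str],
    token_counts: list[int],
    max_tokens_per_batch: int,
) -> list[list[list[int] | str]]:
    """Group token chunks into batches whose summed counts stay under the limit."""
    batches: list[list[list[int] | str]] = []
    current: list[list[int] | str] = []
    current_total = 0
    for chunk, count in zip(tokens, token_counts):
        # Close the current batch if adding this chunk would exceed the budget.
        if current and current_total + count > max_tokens_per_batch:
            batches.append(current)
            current, current_total = [], 0
        current.append(chunk)
        current_total += count
    if current:
        batches.append(current)
    return batches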
@@ -603,12 +601,12 @@ async def _aget_len_safe_embeddings(
     ) -> list[list[float]]:
         """Asynchronously generate length-safe embeddings for a list of texts.
 
-        This method handles tokenization and asynchronous embedding generation,
-        respecting the set embedding context length and chunk size. It supports both
-        `tiktoken` and HuggingFace `tokenizer` based on the tiktoken_enabled flag.
+        This method handles tokenization and embedding generation, respecting the
+        `embedding_ctx_length` and `chunk_size`. Supports both `tiktoken` and
+        HuggingFace `transformers` based on the `tiktoken_enabled` flag.
 
         Args:
-            texts: A list of texts to embed.
+            texts: The list of texts to embed.
             engine: The engine or model to use for embeddings.
             chunk_size: The size of chunks for processing embeddings.
 
@@ -617,14 +615,10 @@ async def _aget_len_safe_embeddings(
         """
         _chunk_size = chunk_size or self.chunk_size
         client_kwargs = {**self._invocation_params, **kwargs}
-        _iter, tokens, indices = await run_in_executor(
+        _iter, tokens, indices, token_counts = await run_in_executor(
             None, self._tokenize, texts, _chunk_size
         )
         batched_embeddings: list[list[float]] = []
-        # Calculate token counts per chunk
-        token_counts = [
-            len(t) if isinstance(t, list) else len(t.split()) for t in tokens
-        ]
 
         # Process in batches respecting the token limit
         i = 0
@@ -676,12 +670,13 @@ async def empty_embedding() -> list[float]:
     def embed_documents(
         self, texts: list[str], chunk_size: int | None = None, **kwargs: Any
     ) -> list[list[float]]:
-        """Call out to OpenAI's embedding endpoint for embedding search docs.
+        """Call OpenAI's embedding endpoint to embed search docs.
 
         Args:
             texts: The list of texts to embed.
-            chunk_size: The chunk size of embeddings. If `None`, will use the chunk size
-                specified by the class.
+            chunk_size: The chunk size of embeddings.
+
+                If `None`, will use the chunk size specified by the class.
             kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
@@ -701,8 +696,8 @@ def embed_documents(
                 embeddings.extend(r["embedding"] for r in response["data"])
             return embeddings
 
-        # NOTE: to keep things simple, we assume the list may contain texts longer
-        # than the maximum context and use length-safe embedding function.
+        # Unconditionally call _get_len_safe_embeddings to handle length safety.
+        # This could be optimized to avoid double work when all texts are short enough.
         engine = cast(str, self.deployment)
         return self._get_len_safe_embeddings(
             texts, engine=engine, chunk_size=chunk_size, **kwargs
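
End-user behavior is unchanged: `embed_documents` still accepts an optional per-call `chunk_size`. A quick usage sketch (model name is illustrative; assumes OPENAI_API_KEY is set in the environment):

from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vectors = embeddings.embed_documents(
    ["first document", "second document"],
    chunk_size=500,  # overrides the class-level chunk_size for this call
)
print(len(vectors), len(vectors[0]))  # one vector per input text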
@@ -711,12 +706,13 @@ def embed_documents(
     async def aembed_documents(
         self, texts: list[str], chunk_size: int | None = None, **kwargs: Any
     ) -> list[list[float]]:
-        """Call out to OpenAI's embedding endpoint async for embedding search docs.
+        """Asynchronously call OpenAI's embedding endpoint to embed search docs.
 
         Args:
             texts: The list of texts to embed.
-            chunk_size: The chunk size of embeddings. If `None`, will use the chunk size
-                specified by the class.
+            chunk_size: The chunk size of embeddings.
+
+                If `None`, will use the chunk size specified by the class.
             kwargs: Additional keyword arguments to pass to the embedding API.
 
         Returns:
@@ -735,8 +731,8 @@ async def aembed_documents(
                 embeddings.extend(r["embedding"] for r in response["data"])
             return embeddings
 
-        # NOTE: to keep things simple, we assume the list may contain texts longer
-        # than the maximum context and use length-safe embedding function.
+        # Unconditionally call _get_len_safe_embeddings to handle length safety.
+        # This could be optimized to avoid double work when all texts are short enough.
         engine = cast(str, self.deployment)
         return await self._aget_len_safe_embeddings(
             texts, engine=engine, chunk_size=chunk_size, **kwargs

libs/partners/openai/tests/unit_tests/embeddings/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_embed_documents_with_custom_chunk_size() -> None:
     ]
 
     result = embeddings.embed_documents(texts, chunk_size=custom_chunk_size)
-    _, tokens, __ = embeddings._tokenize(texts, custom_chunk_size)
+    _, tokens, __, ___ = embeddings._tokenize(texts, custom_chunk_size)
     mock_create.call_args
     mock_create.assert_any_call(input=tokens[0:3], **embeddings._invocation_params)
     mock_create.assert_any_call(input=tokens[3:4], **embeddings._invocation_params)

libs/partners/openai/uv.lock

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default.
