
Commit 2d4f00a

Kaparthy Reddy, ccurme, and mdrxy authored
fix(openai): Respect 300k token limit for embeddings API requests (#33668)
## Description

Fixes #31227 - Resolves the issue where `OpenAIEmbeddings` exceeds OpenAI's 300,000 token per request limit, causing 400 BadRequest errors.

## Problem

When embedding large document sets, LangChain would send batches containing more than 300,000 tokens in a single API request, causing this error:

```
openai.BadRequestError: Error code: 400 - {'error': {'message': 'Requested 673477 tokens, max 300000 tokens per request'}}
```

The issue occurred because:

- The code chunks texts by `embedding_ctx_length` (8191 tokens per chunk)
- Then batches chunks by `chunk_size` (default 1000 chunks per request)
- **But didn't check**: the total tokens per batch against OpenAI's 300k limit
- Result: `1000 chunks × 8191 tokens = 8,191,000 tokens` → exceeds the limit

## Solution

This PR implements dynamic batching that respects the 300k token limit:

1. **Added constant**: `MAX_TOKENS_PER_REQUEST = 300000`
2. **Track token counts**: Calculate the actual tokens for each chunk
3. **Dynamic batching**: Instead of fixed `chunk_size` batches, accumulate chunks until approaching the 300k limit (a standalone sketch of this loop follows this description)
4. **Applied to both sync and async**: Fixed both `_get_len_safe_embeddings` and `_aget_len_safe_embeddings`

## Changes

- Modified `langchain_openai/embeddings/base.py`:
  - Added the `MAX_TOKENS_PER_REQUEST` constant
  - Replaced fixed-size batching with token-aware dynamic batching
  - Applied to both the sync (line ~478) and async (line ~527) methods
- Added a test in `tests/unit_tests/embeddings/test_base.py`:
  - `test_embeddings_respects_token_limit()` - verifies that large document sets are properly batched

## Testing

All existing tests pass (280 passed, 4 xfailed, 1 xpassed). The new test verifies that:

- Large document sets (500 texts × 1000 tokens = 500k tokens) are split into multiple API calls
- Each API call respects the 300k token limit

## Usage

After this fix, users can embed large document sets without errors:

```python
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import CharacterTextSplitter

# This will now work without exceeding token limits
embeddings = OpenAIEmbeddings()
documents = CharacterTextSplitter().split_documents(large_documents)
Chroma.from_documents(documents, embeddings)
```

Resolves #31227

---------

Co-authored-by: Kaparthy Reddy <[email protected]>
Co-authored-by: Chester Curme <[email protected]>
Co-authored-by: Mason Daugherty <[email protected]>
Co-authored-by: Mason Daugherty <[email protected]>
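
For reference, here is a minimal, standalone sketch of the greedy batching loop described in the Solution above, factored out of the embeddings class for readability. It is illustrative only: `batch_by_token_limit` is a hypothetical helper name, not part of the `langchain_openai` API; the merged change inlines this logic directly in `_get_len_safe_embeddings` and `_aget_len_safe_embeddings` (see the diff below).

```python
from collections.abc import Iterator, Sequence

# Constant added by this PR in langchain_openai/embeddings/base.py.
MAX_TOKENS_PER_REQUEST = 300000


# Illustrative helper (not part of langchain_openai); it mirrors the
# batching loop added to _get_len_safe_embeddings in this commit.
def batch_by_token_limit(
    token_counts: Sequence[int],
    chunk_size: int = 1000,
    max_tokens: int = MAX_TOKENS_PER_REQUEST,
) -> Iterator[tuple[int, int]]:
    """Yield (start, end) chunk ranges that respect both per-request caps.

    Chunks are accumulated greedily until adding the next one would exceed
    either the chunk-count cap (`chunk_size`) or the token budget
    (`max_tokens`). A single oversized chunk is still emitted on its own so
    the loop always makes progress.
    """
    i = 0
    while i < len(token_counts):
        batch_token_count = 0
        batch_end = i
        for j in range(i, min(i + chunk_size, len(token_counts))):
            if batch_token_count + token_counts[j] > max_tokens:
                if batch_end == i:
                    # A lone chunk over the limit is still sent on its own.
                    batch_end = j + 1
                break
            batch_token_count += token_counts[j]
            batch_end = j + 1
        yield i, batch_end
        i = batch_end


if __name__ == "__main__":
    # 500 chunks of ~1000 tokens each (the unit-test scenario) split into a
    # 300-chunk request (300k tokens) followed by a 200-chunk request.
    print(list(batch_by_token_limit([1000] * 500)))  # [(0, 300), (300, 500)]
```
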
1 parent 9bd401a commit 2d4f00a

2 files changed: +115 −11 lines changed


libs/partners/openai/langchain_openai/embeddings/base.py

Lines changed: 62 additions & 10 deletions

```diff
@@ -19,6 +19,9 @@
 
 logger = logging.getLogger(__name__)
 
+MAX_TOKENS_PER_REQUEST = 300000
+"""API limit per request for embedding tokens."""
+
 
 def _process_batched_chunked_embeddings(
     num_texts: int,
@@ -524,9 +527,9 @@ def _get_len_safe_embeddings(
     ) -> list[list[float]]:
         """Generate length-safe embeddings for a list of texts.
 
-        This method handles tokenization and embedding generation, respecting the
-        set embedding context length and chunk size. It supports both tiktoken
-        and HuggingFace tokenizer based on the tiktoken_enabled flag.
+        This method handles tokenization and embedding generation, respecting the set
+        embedding context length and chunk size. It supports both tiktoken and
+        HuggingFace tokenizer based on the tiktoken_enabled flag.
 
         Args:
             texts: A list of texts to embed.
@@ -540,14 +543,38 @@ def _get_len_safe_embeddings(
         client_kwargs = {**self._invocation_params, **kwargs}
         _iter, tokens, indices = self._tokenize(texts, _chunk_size)
         batched_embeddings: list[list[float]] = []
-        for i in _iter:
-            response = self.client.create(
-                input=tokens[i : i + _chunk_size], **client_kwargs
-            )
+        # Calculate token counts per chunk
+        token_counts = [
+            len(t) if isinstance(t, list) else len(t.split()) for t in tokens
+        ]
+
+        # Process in batches respecting the token limit
+        i = 0
+        while i < len(tokens):
+            # Determine how many chunks we can include in this batch
+            batch_token_count = 0
+            batch_end = i
+
+            for j in range(i, min(i + _chunk_size, len(tokens))):
+                chunk_tokens = token_counts[j]
+                # Check if adding this chunk would exceed the limit
+                if batch_token_count + chunk_tokens > MAX_TOKENS_PER_REQUEST:
+                    if batch_end == i:
+                        # Single chunk exceeds limit - handle it anyway
+                        batch_end = j + 1
+                    break
+                batch_token_count += chunk_tokens
+                batch_end = j + 1
+
+            # Make API call with this batch
+            batch_tokens = tokens[i:batch_end]
+            response = self.client.create(input=batch_tokens, **client_kwargs)
             if not isinstance(response, dict):
                 response = response.model_dump()
             batched_embeddings.extend(r["embedding"] for r in response["data"])
 
+            i = batch_end
+
         embeddings = _process_batched_chunked_embeddings(
             len(texts), tokens, batched_embeddings, indices, self.skip_empty
         )
@@ -594,15 +621,40 @@ async def _aget_len_safe_embeddings(
             None, self._tokenize, texts, _chunk_size
         )
         batched_embeddings: list[list[float]] = []
-        for i in range(0, len(tokens), _chunk_size):
+        # Calculate token counts per chunk
+        token_counts = [
+            len(t) if isinstance(t, list) else len(t.split()) for t in tokens
+        ]
+
+        # Process in batches respecting the token limit
+        i = 0
+        while i < len(tokens):
+            # Determine how many chunks we can include in this batch
+            batch_token_count = 0
+            batch_end = i
+
+            for j in range(i, min(i + _chunk_size, len(tokens))):
+                chunk_tokens = token_counts[j]
+                # Check if adding this chunk would exceed the limit
+                if batch_token_count + chunk_tokens > MAX_TOKENS_PER_REQUEST:
+                    if batch_end == i:
+                        # Single chunk exceeds limit - handle it anyway
+                        batch_end = j + 1
+                    break
+                batch_token_count += chunk_tokens
+                batch_end = j + 1
+
+            # Make API call with this batch
+            batch_tokens = tokens[i:batch_end]
             response = await self.async_client.create(
-                input=tokens[i : i + _chunk_size], **client_kwargs
+                input=batch_tokens, **client_kwargs
             )
-
             if not isinstance(response, dict):
                 response = response.model_dump()
             batched_embeddings.extend(r["embedding"] for r in response["data"])
 
+            i = batch_end
+
         embeddings = _process_batched_chunked_embeddings(
             len(texts), tokens, batched_embeddings, indices, self.skip_empty
         )
```

libs/partners/openai/tests/unit_tests/embeddings/test_base.py

Lines changed: 53 additions & 1 deletion

```diff
@@ -1,7 +1,9 @@
 import os
-from unittest.mock import patch
+from typing import Any
+from unittest.mock import Mock, patch
 
 import pytest
+from pydantic import SecretStr
 
 from langchain_openai import OpenAIEmbeddings
 
@@ -96,3 +98,53 @@ async def test_embed_with_kwargs_async() -> None:
     mock_create.assert_any_call(input=texts, **client_kwargs)
 
     assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+
+
+def test_embeddings_respects_token_limit() -> None:
+    """Test that embeddings respect the 300k token per request limit."""
+    # Create embeddings instance
+    embeddings = OpenAIEmbeddings(
+        model="text-embedding-ada-002", api_key=SecretStr("test-key")
+    )
+
+    call_counts = []
+
+    def mock_create(**kwargs: Any) -> Mock:
+        input_ = kwargs["input"]
+        # Track how many tokens in this call
+        if isinstance(input_, list):
+            total_tokens = sum(
+                len(t) if isinstance(t, list) else len(t.split()) for t in input_
+            )
+            call_counts.append(total_tokens)
+            # Verify this call doesn't exceed limit
+            assert total_tokens <= 300000, (
+                f"Batch exceeded token limit: {total_tokens} tokens"
+            )
+
+        # Return mock response
+        mock_response = Mock()
+        mock_response.model_dump.return_value = {
+            "data": [
+                {"embedding": [0.1] * 1536}
+                for _ in range(len(input_) if isinstance(input_, list) else 1)
+            ]
+        }
+        return mock_response
+
+    embeddings.client.create = mock_create
+
+    # Create a scenario that would exceed 300k tokens in a single batch
+    # with default chunk_size=1000
+    # Simulate 500 texts with ~1000 tokens each = 500k tokens total
+    large_texts = ["word " * 1000 for _ in range(500)]
+
+    # This should not raise an error anymore
+    embeddings.embed_documents(large_texts)
+
+    # Verify we made multiple API calls to respect the limit
+    assert len(call_counts) > 1, "Should have split into multiple batches"
+
+    # Verify each call respected the limit
+    for count in call_counts:
+        assert count <= 300000, f"Batch exceeded limit: {count}"
```

0 commit comments
