lib/crewai/src/crewai/llms/providers/openai/completion.py (38 additions, 0 deletions)
@@ -297,6 +297,7 @@ def _prepare_completion_params(
        }
        if self.stream:
            params["stream"] = self.stream
            params["stream_options"] = {"include_usage": True}

        params.update(self.additional_params)
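
For context: with stream_options={"include_usage": True}, the OpenAI API appends one terminal chunk whose choices list is empty and whose usage field carries the token totals for the whole stream. A minimal sketch against the raw openai v1 client (the model name and prompt are placeholders, not taken from this PR):

from openai import OpenAI

client = OpenAI()
stream = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},
)

usage = None
for chunk in stream:
    if chunk.usage is not None:
        # Terminal chunk: choices == [] and usage is populated.
        usage = chunk.usage

print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)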

@@ -545,6 +546,9 @@ def _handle_streaming_completion(

            final_completion = stream.get_final_completion()
            if final_completion and final_completion.choices:
                usage = self._extract_openai_token_usage(final_completion)
                self._track_token_usage_internal(usage)

                parsed_result = final_completion.choices[0].message.parsed
                if parsed_result:
                    structured_json = parsed_result.model_dump_json()
@@ -564,7 +568,11 @@ def _handle_streaming_completion(
            self.client.chat.completions.create(**params)
        )

        usage_data: dict[str, Any] | None = None
        for completion_chunk in completion_stream:
            if completion_chunk.usage is not None:
                usage_data = self._extract_chunk_token_usage(completion_chunk)

            if not completion_chunk.choices:
                continue
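
Note the ordering here: the usage-bearing terminal chunk arrives with an empty choices list, so chunk.usage must be checked before the empty-choices guard, or the usage data would be skipped along with the content handling. A condensed, self-contained illustration of the pattern (duck-typed chunks, not the real SDK types):

def consume_stream(chunks):
    """Why the usage check must precede the empty-choices guard."""
    usage = None
    parts = []
    for chunk in chunks:
        if chunk.usage is not None:
            usage = chunk.usage  # captured even though choices is empty
        if not chunk.choices:
            continue  # the terminal usage chunk carries no content delta
        delta = chunk.choices[0].delta
        if delta.content:
            parts.append(delta.content)
    return "".join(parts), usage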

@@ -593,6 +601,9 @@
                    if tool_call.function and tool_call.function.arguments:
                        tool_calls[call_id]["arguments"] += tool_call.function.arguments

        if usage_data:
            self._track_token_usage_internal(usage_data)

        if tool_calls and available_functions:
            for call_data in tool_calls.values():
                function_name = call_data["name"]
@@ -785,7 +796,11 @@ async def _ahandle_streaming_completion(
        ] = await self.async_client.chat.completions.create(**params)

        accumulated_content = ""
        usage_data: dict[str, Any] | None = None
        async for chunk in completion_stream:
            if chunk.usage is not None:
                usage_data = self._extract_chunk_token_usage(chunk)

            if not chunk.choices:
                continue

@@ -800,6 +815,9 @@
                    from_agent=from_agent,
                )

        if usage_data:
            self._track_token_usage_internal(usage_data)

        try:
            parsed_object = response_model.model_validate_json(accumulated_content)
            structured_json = parsed_object.model_dump_json()
@@ -828,7 +846,11 @@ async def _ahandle_streaming_completion(
            ChatCompletionChunk
        ] = await self.async_client.chat.completions.create(**params)

        usage_data = None
        async for chunk in stream:
            if chunk.usage is not None:
                usage_data = self._extract_chunk_token_usage(chunk)

            if not chunk.choices:
                continue

@@ -857,6 +879,9 @@
                    if tool_call.function and tool_call.function.arguments:
                        tool_calls[call_id]["arguments"] += tool_call.function.arguments

        if usage_data:
            self._track_token_usage_internal(usage_data)

        if tool_calls and available_functions:
            for call_data in tool_calls.values():
                function_name = call_data["name"]
@@ -955,6 +980,19 @@ def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]
            }
        return {"total_tokens": 0}

    def _extract_chunk_token_usage(
        self, chunk: ChatCompletionChunk
    ) -> dict[str, Any]:
        """Extract token usage from OpenAI ChatCompletionChunk (streaming response)."""
        if hasattr(chunk, "usage") and chunk.usage:
            usage = chunk.usage
            return {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0),
            }
        return {"total_tokens": 0}
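
As a quick sanity check of the helper's two paths, the same logic can be exercised with stand-in objects (illustrative only; the real argument is an SDK ChatCompletionChunk):

from types import SimpleNamespace

def extract_chunk_token_usage(chunk):
    # Freestanding copy of the method above, for illustration.
    if getattr(chunk, "usage", None):
        usage = chunk.usage
        return {
            "prompt_tokens": getattr(usage, "prompt_tokens", 0),
            "completion_tokens": getattr(usage, "completion_tokens", 0),
            "total_tokens": getattr(usage, "total_tokens", 0),
        }
    return {"total_tokens": 0}

terminal = SimpleNamespace(
    usage=SimpleNamespace(prompt_tokens=10, completion_tokens=20, total_tokens=30)
)
assert extract_chunk_token_usage(terminal)["total_tokens"] == 30
assert extract_chunk_token_usage(SimpleNamespace(usage=None)) == {"total_tokens": 0}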

    def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]:
        """Format messages for OpenAI API."""
        base_formatted = super()._format_messages(messages)
lib/crewai/tests/llms/openai/test_openai.py (130 additions, 0 deletions)
@@ -592,3 +592,133 @@ def test_openai_response_format_none():

    assert isinstance(result, str)
    assert len(result) > 0


def test_openai_streaming_tracks_token_usage():
    """
    Test that streaming mode correctly tracks token usage.
    This test verifies the fix for issue #4056 where token usage was always 0
    when using streaming mode.
    """
    llm = LLM(model="openai/gpt-4o", stream=True)

    # Create mock chunks with usage in the final chunk
    mock_chunk1 = MagicMock()
    mock_chunk1.choices = [MagicMock()]
    mock_chunk1.choices[0].delta = MagicMock()
    mock_chunk1.choices[0].delta.content = "Hello "
    mock_chunk1.choices[0].delta.tool_calls = None
    mock_chunk1.usage = None

    mock_chunk2 = MagicMock()
    mock_chunk2.choices = [MagicMock()]
    mock_chunk2.choices[0].delta = MagicMock()
    mock_chunk2.choices[0].delta.content = "World!"
    mock_chunk2.choices[0].delta.tool_calls = None
    mock_chunk2.usage = None

    # Final chunk with usage information (when stream_options={"include_usage": True})
    mock_chunk3 = MagicMock()
    mock_chunk3.choices = []
    mock_chunk3.usage = MagicMock()
    mock_chunk3.usage.prompt_tokens = 10
    mock_chunk3.usage.completion_tokens = 20
    mock_chunk3.usage.total_tokens = 30

    mock_stream = MagicMock()
    mock_stream.__iter__ = MagicMock(
        return_value=iter([mock_chunk1, mock_chunk2, mock_chunk3])
    )

    with patch.object(llm.client.chat.completions, "create", return_value=mock_stream):
        result = llm.call("Hello")

    assert result == "Hello World!"

    # Verify token usage was tracked
    usage = llm.get_token_usage_summary()
    assert usage.prompt_tokens == 10
    assert usage.completion_tokens == 20
    assert usage.total_tokens == 30
    assert usage.successful_requests == 1


def test_openai_streaming_with_response_model_tracks_token_usage():
    """
    Test that streaming with response_model correctly tracks token usage.
    This test verifies the fix for issue #4056 where token usage was always 0
    when using streaming mode with response_model.
    """
    from pydantic import BaseModel

    class TestResponse(BaseModel):
        """Test response model."""

        answer: str
        confidence: float

    llm = LLM(model="openai/gpt-4o", stream=True)

    with patch.object(llm.client.beta.chat.completions, "stream") as mock_stream:
        # Create mock chunks with content.delta event structure
        mock_chunk1 = MagicMock()
        mock_chunk1.type = "content.delta"
        mock_chunk1.delta = '{"answer": "test", '

        mock_chunk2 = MagicMock()
        mock_chunk2.type = "content.delta"
        mock_chunk2.delta = '"confidence": 0.95}'

        # Create mock final completion with parsed result and usage
        mock_parsed = TestResponse(answer="test", confidence=0.95)
        mock_message = MagicMock()
        mock_message.parsed = mock_parsed
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_final_completion = MagicMock()
        mock_final_completion.choices = [mock_choice]
        mock_final_completion.usage = MagicMock()
        mock_final_completion.usage.prompt_tokens = 15
        mock_final_completion.usage.completion_tokens = 25
        mock_final_completion.usage.total_tokens = 40

        # Create mock stream context manager
        mock_stream_obj = MagicMock()
        mock_stream_obj.__enter__ = MagicMock(return_value=mock_stream_obj)
        mock_stream_obj.__exit__ = MagicMock(return_value=None)
        mock_stream_obj.__iter__ = MagicMock(
            return_value=iter([mock_chunk1, mock_chunk2])
        )
        mock_stream_obj.get_final_completion = MagicMock(
            return_value=mock_final_completion
        )

        mock_stream.return_value = mock_stream_obj

        result = llm.call("Test question", response_model=TestResponse)

    assert result is not None

    # Verify token usage was tracked
    usage = llm.get_token_usage_summary()
    assert usage.prompt_tokens == 15
    assert usage.completion_tokens == 25
    assert usage.total_tokens == 40
    assert usage.successful_requests == 1


def test_openai_streaming_params_include_usage():
    """
    Test that streaming mode includes stream_options with include_usage=True.
    This ensures the OpenAI API will return usage information in the final chunk.
    """
    llm = LLM(model="openai/gpt-4o", stream=True)

    with patch.object(llm.client.chat.completions, "create") as mock_create:
        mock_stream = MagicMock()
        mock_stream.__iter__ = MagicMock(return_value=iter([]))
        mock_create.return_value = mock_stream

        try:
            llm.call("Hello")
        except Exception:
            pass  # We just want to check the call parameters

        # Verify stream_options was included in the API call
        call_kwargs = mock_create.call_args[1]
        assert call_kwargs.get("stream") is True
        assert call_kwargs.get("stream_options") == {"include_usage": True}
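
Taken together, the change means a caller now sees non-zero usage after a streamed call. A minimal end-to-end sketch, assuming LLM is importable from the crewai package as in these tests (requires a real OPENAI_API_KEY; actual token counts will vary):

from crewai import LLM

llm = LLM(model="openai/gpt-4o", stream=True)
llm.call("Say hello in one word.")

usage = llm.get_token_usage_summary()
# Before this fix, all of these were 0 in streaming mode (issue #4056).
print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)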