Skip to content

Commit bf5c295

Browse files
authored
fix IndexError when streaming OpenAI (#181)
1 parent ff7015a commit bf5c295

File tree

5 files changed

+78
-53
lines changed

5 files changed

+78
-53
lines changed

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@
3737
from groq.types import chat
3838
from groq.types.chat import ChatCompletion, ChatCompletionChunk
3939
from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
40-
except ImportError as e:
40+
except ImportError as _import_error:
4141
raise ImportError(
4242
'Please install `groq` to use the Groq model, '
4343
"you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
44-
) from e
44+
) from _import_error
4545

4646
GroqModelName = Literal[
4747
'llama-3.1-70b-versatile',
@@ -209,33 +209,29 @@ def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
209209
@staticmethod
210210
async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
211211
"""Process a streamed response, and prepare a streaming response to return."""
212-
try:
213-
first_chunk = await response.__anext__()
214-
except StopAsyncIteration as e: # pragma: no cover
215-
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
216-
timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
217-
delta = first_chunk.choices[0].delta
218-
start_cost = _map_cost(first_chunk)
219-
220-
# the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
221-
while delta.tool_calls is None and delta.content is None:
212+
timestamp: datetime | None = None
213+
start_cost = Cost()
214+
# the first chunk may not contain enough information, so we iterate until we get either `tool_calls` or `content`
215+
while True:
222216
try:
223-
next_chunk = await response.__anext__()
217+
chunk = await response.__anext__()
224218
except StopAsyncIteration as e:
225219
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
226-
delta = next_chunk.choices[0].delta
227-
start_cost += _map_cost(next_chunk)
228-
229-
if delta.content is not None:
230-
return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
231-
else:
232-
assert delta.tool_calls is not None, f'Expected delta with tool_calls, got {delta}'
233-
return GroqStreamStructuredResponse(
234-
response,
235-
{c.index: c for c in delta.tool_calls},
236-
timestamp,
237-
start_cost,
238-
)
220+
timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
221+
start_cost += _map_cost(chunk)
222+
223+
if chunk.choices:
224+
delta = chunk.choices[0].delta
225+
226+
if delta.content is not None:
227+
return GroqStreamTextResponse(delta.content, response, timestamp, start_cost)
228+
elif delta.tool_calls is not None:
229+
return GroqStreamStructuredResponse(
230+
response,
231+
{c.index: c for c in delta.tool_calls},
232+
timestamp,
233+
start_cost,
234+
)
239235

240236
@staticmethod
241237
def _map_message(message: Message) -> chat.ChatCompletionMessageParam:

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@
3737
from openai.types import ChatModel, chat
3838
from openai.types.chat import ChatCompletionChunk
3939
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
40-
except ImportError as e:
40+
except ImportError as _import_error:
4141
raise ImportError(
4242
'Please install `openai` to use the OpenAI model, '
4343
"you can use the `openai` optional group — `pip install 'pydantic-ai[openai]'`"
44-
) from e
44+
) from _import_error
4545

4646

4747
@dataclass(init=False)
@@ -189,33 +189,31 @@ def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
189189
@staticmethod
190190
async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
191191
"""Process a streamed response, and prepare a streaming response to return."""
192-
try:
193-
first_chunk = await response.__anext__()
194-
except StopAsyncIteration as e: # pragma: no cover
195-
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
196-
timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc)
197-
delta = first_chunk.choices[0].delta
198-
start_cost = _map_cost(first_chunk)
199-
200-
# the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content`
201-
while delta.tool_calls is None and delta.content is None:
192+
timestamp: datetime | None = None
193+
start_cost = Cost()
194+
# the first chunk may not contain enough information, so we iterate until we get either `tool_calls` or `content`
195+
while True:
202196
try:
203-
next_chunk = await response.__anext__()
197+
chunk = await response.__anext__()
204198
except StopAsyncIteration as e:
205199
raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
206-
delta = next_chunk.choices[0].delta
207-
start_cost += _map_cost(next_chunk)
208200

209-
if delta.content is not None:
210-
return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
211-
else:
212-
assert delta.tool_calls is not None, f'Expected delta with tool_calls, got {delta}'
213-
return OpenAIStreamStructuredResponse(
214-
response,
215-
{c.index: c for c in delta.tool_calls},
216-
timestamp,
217-
start_cost,
218-
)
201+
timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc)
202+
start_cost += _map_cost(chunk)
203+
204+
if chunk.choices:
205+
delta = chunk.choices[0].delta
206+
207+
if delta.content is not None:
208+
return OpenAIStreamTextResponse(delta.content, response, timestamp, start_cost)
209+
elif delta.tool_calls is not None:
210+
return OpenAIStreamStructuredResponse(
211+
response,
212+
{c.index: c for c in delta.tool_calls},
213+
timestamp,
214+
start_cost,
215+
)
216+
# else continue until we get either delta.content or delta.tool_calls
219217

220218
@staticmethod
221219
def _map_message(message: Message) -> chat.ChatCompletionMessageParam:

pydantic_ai_slim/pydantic_ai/models/vertexai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
from google.auth.credentials import Credentials as BaseCredentials
1919
from google.auth.transport.requests import Request
2020
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
21-
except ImportError as e:
21+
except ImportError as _import_error:
2222
raise ImportError(
2323
'Please install `google-auth` to use the VertexAI model, '
2424
"you can use the `vertexai` optional group — `pip install 'pydantic-ai[vertexai]'`"
25-
) from e
25+
) from _import_error
2626

2727
VERTEX_AI_URL_TEMPLATE = (
2828
'https://{region}-aiplatform.googleapis.com/v1'

tests/models/test_groq.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,3 +450,16 @@ async def test_no_content(allow_model_requests: None):
450450
with pytest.raises(UnexpectedModelBehavior, match='Streamed response ended without con'):
451451
async with agent.run_stream(''):
452452
pass # pragma: no cover
453+
454+
455+
async def test_no_delta(allow_model_requests: None):
456+
stream = chunk([]), text_chunk('hello '), text_chunk('world')
457+
mock_client = MockGroq.create_mock_stream(stream)
458+
m = GroqModel('llama-3.1-70b-versatile', groq_client=mock_client)
459+
agent = Agent(m)
460+
461+
async with agent.run_stream('') as result:
462+
assert not result.is_structured
463+
assert not result.is_complete
464+
assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world'])
465+
assert result.is_complete

tests/models/test_openai.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,3 +440,21 @@ async def test_no_content(allow_model_requests: None):
440440
with pytest.raises(UnexpectedModelBehavior, match='Streamed response ended without con'):
441441
async with agent.run_stream(''):
442442
pass
443+
444+
445+
async def test_no_delta(allow_model_requests: None):
446+
stream = (
447+
chunk([]),
448+
text_chunk('hello '),
449+
text_chunk('world'),
450+
)
451+
mock_client = MockOpenAI.create_mock_stream(stream)
452+
m = OpenAIModel('gpt-4', openai_client=mock_client)
453+
agent = Agent(m)
454+
455+
async with agent.run_stream('') as result:
456+
assert not result.is_structured
457+
assert not result.is_complete
458+
assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world'])
459+
assert result.is_complete
460+
assert result.cost() == snapshot(Cost(request_tokens=6, response_tokens=3, total_tokens=9))

0 commit comments

Comments
 (0)