|
37 | 37 | from groq.types import chat
|
38 | 38 | from groq.types.chat import ChatCompletion, ChatCompletionChunk
|
39 | 39 | from groq.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
|
40 |
| -except ImportError as e: |
| 40 | +except ImportError as _import_error: |
41 | 41 | raise ImportError(
|
42 | 42 | 'Please install `groq` to use the Groq model, '
|
43 | 43 | "you can use the `groq` optional group — `pip install 'pydantic-ai[groq]'`"
|
44 |
| - ) from e |
| 44 | + ) from _import_error |
45 | 45 |
|
46 | 46 | GroqModelName = Literal[
|
47 | 47 | 'llama-3.1-70b-versatile',
|
@@ -209,33 +209,29 @@ def _process_response(response: chat.ChatCompletion) -> ModelAnyResponse:
|
209 | 209 | @staticmethod
|
210 | 210 | async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk]) -> EitherStreamedResponse:
|
211 | 211 | """Process a streamed response, and prepare a streaming response to return."""
|
212 |
| - try: |
213 |
| - first_chunk = await response.__anext__() |
214 |
| - except StopAsyncIteration as e: # pragma: no cover |
215 |
| - raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e |
216 |
| - timestamp = datetime.fromtimestamp(first_chunk.created, tz=timezone.utc) |
217 |
| - delta = first_chunk.choices[0].delta |
218 |
| - start_cost = _map_cost(first_chunk) |
219 |
| - |
220 |
| - # the first chunk may only contain `role`, so we iterate until we get either `tool_calls` or `content` |
221 |
| - while delta.tool_calls is None and delta.content is None: |
| 212 | + timestamp: datetime | None = None |
| 213 | + start_cost = Cost() |
| 214 | + # the first chunk may contain enough information so we iterate until we get either `tool_calls` or `content` |
| 215 | + while True: |
222 | 216 | try:
|
223 |
| - next_chunk = await response.__anext__() |
| 217 | + chunk = await response.__anext__() |
224 | 218 | except StopAsyncIteration as e:
|
225 | 219 | raise UnexpectedModelBehavior('Streamed response ended without content or tool calls') from e
|
226 |
| - delta = next_chunk.choices[0].delta |
227 |
| - start_cost += _map_cost(next_chunk) |
228 |
| - |
229 |
| - if delta.content is not None: |
230 |
| - return GroqStreamTextResponse(delta.content, response, timestamp, start_cost) |
231 |
| - else: |
232 |
| - assert delta.tool_calls is not None, f'Expected delta with tool_calls, got {delta}' |
233 |
| - return GroqStreamStructuredResponse( |
234 |
| - response, |
235 |
| - {c.index: c for c in delta.tool_calls}, |
236 |
| - timestamp, |
237 |
| - start_cost, |
238 |
| - ) |
| 220 | + timestamp = timestamp or datetime.fromtimestamp(chunk.created, tz=timezone.utc) |
| 221 | + start_cost += _map_cost(chunk) |
| 222 | + |
| 223 | + if chunk.choices: |
| 224 | + delta = chunk.choices[0].delta |
| 225 | + |
| 226 | + if delta.content is not None: |
| 227 | + return GroqStreamTextResponse(delta.content, response, timestamp, start_cost) |
| 228 | + elif delta.tool_calls is not None: |
| 229 | + return GroqStreamStructuredResponse( |
| 230 | + response, |
| 231 | + {c.index: c for c in delta.tool_calls}, |
| 232 | + timestamp, |
| 233 | + start_cost, |
| 234 | + ) |
239 | 235 |
|
240 | 236 | @staticmethod
|
241 | 237 | def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
|
|
0 commit comments