6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -609,6 +609,7 @@ def get(self) -> ModelResponse:
timestamp=self.timestamp,
usage=self.usage(),
provider_name=self.provider_name,
provider_response_id=self.provider_response_id,
)

def usage(self) -> RequestUsage:
@@ -627,6 +628,11 @@ def provider_name(self) -> str | None:
"""Get the provider name."""
raise NotImplementedError()

@property
def provider_response_id(self) -> str | None:
"""Get the provider response id."""
return None

@property
@abstractmethod
def timestamp(self) -> datetime:
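For orientation, the base-class change means `get()` now copies `provider_response_id` onto the final `ModelResponse`, with `None` as the default for providers that don't expose an id. A minimal sketch of that default-plus-override pattern (illustrative class names, not the library code):

```python
# Sketch only: the base class defaults provider_response_id to None; a
# provider-specific subclass overrides it with the id captured from the stream.
from __future__ import annotations

from dataclasses import dataclass


class BaseStreamSketch:
    @property
    def provider_response_id(self) -> str | None:
        # Providers that don't expose a response id inherit this default.
        return None


@dataclass
class ProviderStreamSketch(BaseStreamSketch):
    _provider_response_id: str  # e.g. taken from the first streamed chunk

    @property
    def provider_response_id(self) -> str:
        return self._provider_response_id


print(BaseStreamSketch().provider_response_id)                 # None
print(ProviderStreamSketch('resp_123').provider_response_id)   # resp_123
```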
14 changes: 14 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -532,6 +532,7 @@ async def _process_streamed_response(
_response=peekable_response,
_timestamp=number_to_datetime(first_chunk.created),
_provider_name=self._provider.name,
_provider_response_id=first_chunk.id,
)

def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[chat.ChatCompletionToolParam]:
@@ -847,6 +848,7 @@ async def _process_streamed_response(
_response=peekable_response,
_timestamp=number_to_datetime(first_chunk.response.created_at),
_provider_name=self._provider.name,
_provider_response_id=first_chunk.response.id,
)

@overload
@@ -1161,6 +1163,7 @@ class OpenAIStreamedResponse(StreamedResponse):
_response: AsyncIterable[ChatCompletionChunk]
_timestamp: datetime
_provider_name: str
_provider_response_id: str

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
async for chunk in self._response:
@@ -1209,6 +1212,11 @@ def provider_name(self) -> str:
"""Get the provider name."""
return self._provider_name

@property
def provider_response_id(self) -> str:
"""Get the provider response id."""
return self._provider_response_id

@property
def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
@@ -1223,6 +1231,7 @@ class OpenAIResponsesStreamedResponse(StreamedResponse):
_response: AsyncIterable[responses.ResponseStreamEvent]
_timestamp: datetime
_provider_name: str
_provider_response_id: str

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
async for chunk in self._response:
@@ -1345,6 +1354,11 @@ def provider_name(self) -> str:
"""Get the provider name."""
return self._provider_name

@property
def provider_response_id(self) -> str:
"""Get the provider response id."""
return self._provider_response_id

@property
def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
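Both OpenAI stream wrappers record the id from the first chunk (`first_chunk.id` / `first_chunk.response.id`) while the rest of the stream is still consumed normally. A self-contained sketch of that peek-and-replay idea, where `Chunk` and `peek_first` are hypothetical stand-ins rather than pydantic-ai internals:

```python
# Illustrative only: grab the provider response id from the first chunk
# without losing that chunk for downstream consumers.
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import AsyncIterator


@dataclass
class Chunk:
    id: str
    text: str


async def peek_first(stream: AsyncIterator[Chunk]) -> tuple[Chunk, AsyncIterator[Chunk]]:
    """Pull the first chunk, then return an iterator that replays it."""
    first = await stream.__anext__()

    async def replay() -> AsyncIterator[Chunk]:
        yield first
        async for chunk in stream:
            yield chunk

    return first, replay()


async def demo() -> None:
    async def fake_stream() -> AsyncIterator[Chunk]:
        for word in ['hello', ' world']:
            yield Chunk(id='resp_abc', text=word)

    first, stream = await peek_first(fake_stream())
    print(first.id)                        # 'resp_abc' -> _provider_response_id
    print([c.text async for c in stream])  # ['hello', ' world'] -- nothing lost


asyncio.run(demo())
```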
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/test.py
@@ -83,6 +83,7 @@ class TestModel(Model):
"""
_model_name: str = field(default='test', repr=False)
_system: str = field(default='test', repr=False)
__provider_response_id: str = field(default='resp_test', repr=False)

def __init__(
self,
@@ -132,6 +133,7 @@ async def request_stream(
_structured_response=model_response,
_messages=messages,
_provider_name=self._system,
_provider_response_id=self.__provider_response_id,
)

@property
@@ -285,6 +287,7 @@ class TestStreamedResponse(StreamedResponse):
_structured_response: ModelResponse
_messages: InitVar[Iterable[ModelMessage]]
_provider_name: str
_provider_response_id: str
_timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

def __post_init__(self, _messages: Iterable[ModelMessage]):
@@ -327,6 +330,11 @@ def model_name(self) -> str:
"""Get the model name of the response."""
return self._model_name

@property
def provider_response_id(self) -> str:
"""Get the provider name."""
return self._provider_response_id

@property
def provider_name(self) -> str:
"""Get the provider name."""
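The snapshot updates below all follow from the `TestModel` change above: every streamed `ModelResponse` now carries the fixed id `'resp_test'`. A hedged end-to-end sketch of what that looks like through the public API (assuming the usual `Agent`/`TestModel` imports):

```python
# Sketch: after a streamed run against TestModel, each ModelResponse in the
# message history should expose provider_response_id == 'resp_test'.
import asyncio

from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse
from pydantic_ai.models.test import TestModel


async def main() -> None:
    agent = Agent(TestModel())
    async with agent.run_stream('hello') as result:
        async for _ in result.stream_text():
            pass
    for message in result.all_messages():
        if isinstance(message, ModelResponse):
            print(message.provider_response_id)  # expected: 'resp_test'


asyncio.run(main())
```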
14 changes: 14 additions & 0 deletions tests/test_streaming.py
@@ -68,6 +68,7 @@ async def ret_a(x: str) -> str:
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelRequest(
parts=[
@@ -98,6 +99,7 @@ async def ret_a(x: str) -> str:
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelRequest(
parts=[
@@ -112,6 +114,7 @@ async def ret_a(x: str) -> str:
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
]
)
@@ -230,48 +233,55 @@ def upcase(text: str) -> str:
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelResponse(
parts=[TextPart(content='The cat ')],
usage=RequestUsage(input_tokens=51, output_tokens=2),
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelResponse(
parts=[TextPart(content='The cat sat ')],
usage=RequestUsage(input_tokens=51, output_tokens=3),
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelResponse(
parts=[TextPart(content='The cat sat on ')],
usage=RequestUsage(input_tokens=51, output_tokens=4),
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelResponse(
parts=[TextPart(content='The cat sat on the ')],
usage=RequestUsage(input_tokens=51, output_tokens=5),
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelResponse(
parts=[TextPart(content='The cat sat on the mat.')],
usage=RequestUsage(input_tokens=51, output_tokens=7),
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelResponse(
parts=[TextPart(content='The cat sat on the mat.')],
usage=RequestUsage(input_tokens=51, output_tokens=7),
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
]
)
@@ -796,6 +806,7 @@ def regular_tool(x: int) -> int:
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelRequest(
parts=[
@@ -810,6 +821,7 @@ def regular_tool(x: int) -> int:
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelRequest(
parts=[
@@ -914,6 +926,7 @@ def output_validator_simple(data: str) -> str:
timestamp=IsNow(tz=timezone.utc),
kind='response',
provider_name='test',
provider_response_id='resp_test',
)
for text in [
'',
@@ -1197,6 +1210,7 @@ def my_tool(x: int) -> int:
model_name='test',
timestamp=IsDatetime(),
provider_name='test',
provider_response_id='resp_test',
)
]
)
1 change: 1 addition & 0 deletions tests/test_usage_limits.py
@@ -97,6 +97,7 @@ async def ret_a(x: str) -> str:
model_name='test',
timestamp=IsNow(tz=timezone.utc),
provider_name='test',
provider_response_id='resp_test',
),
ModelRequest(
parts=[