feat: Add id and finish_reason to ModelResponse #2325

Open: wants to merge 1 commit into main
36 changes: 34 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -789,8 +789,40 @@ class ModelResponse:
For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
"""

vendor_id: str | None = None
"""Vendor ID as specified by the model provider. This can be used to track the specific request to the model."""
id: str | None = None
"""Unique identifier for the model response, e.g. as returned by the model provider (OpenAI, etc)."""
Collaborator (suggested change):
"""Unique identifier for the model response, e.g. as returned by the model provider (OpenAI, etc)."""
"""Unique identifier for the model response as returned by the model provider."""


finish_reason: str | None = None
Collaborator:
As in the comment below, I think this should only support the values from https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#genai-attributes. If the model returns something else, the raw value can go into vendor_details and a mapped version should go here.
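For illustration, a minimal sketch of the kind of mapping being suggested (the helper name and the target value set are assumptions rather than part of this PR; the authoritative list is the semconv page linked above):

    # Hypothetical helper, not part of this PR: normalize a provider-specific
    # finish reason onto an assumed OTel value set, returning None for values
    # that don't map.
    _OTEL_FINISH_REASONS = {'stop', 'length', 'content_filter', 'tool_calls', 'error'}  # assumed value set

    _RAW_TO_OTEL = {
        'end_turn': 'stop',       # assumed Anthropic-style value
        'max_tokens': 'length',   # assumed Anthropic-style value
        'tool_use': 'tool_calls',
    }

    def _map_finish_reason(raw: str | None) -> str | None:
        """Return a normalized finish reason, or None if the raw value is unrecognized."""
        if raw is None:
            return None
        if raw in _OTEL_FINISH_REASONS:
            return raw
        return _RAW_TO_OTEL.get(raw)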

"""The reason the model finished generating this response, e.g. 'stop', 'length', etc."""

@property
def vendor_id(self) -> str | None:
"""Vendor ID as specified by the model provider. This can be used to track the specific request to the model.

This is deprecated, use `id` instead.
"""
import warnings

warnings.warn('vendor_id is deprecated, use id instead', DeprecationWarning, stacklevel=2)
return self.id

@vendor_id.setter
def vendor_id(self, value: str | None) -> None:
"""Set the vendor ID.

This is deprecated, use `id` instead.
"""
import warnings

warnings.warn('vendor_id is deprecated, use id instead', DeprecationWarning, stacklevel=2)
self.id = value

def __post_init__(self) -> None:
"""Ensure vendor_details contains finish_reason for backward compatibility."""
if self.finish_reason and self.vendor_details is None:
self.vendor_details = {}
if self.finish_reason and self.vendor_details is not None:
self.vendor_details['finish_reason'] = self.finish_reason
Collaborator:
The logic here can be simplified a bit: after the first if, self.vendor_details can no longer be None when finish_reason is set, so the 'is not None' check in the second if is redundant.
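For reference, a sketch of the simplified version (behaviour unchanged):

    def __post_init__(self) -> None:
        """Ensure vendor_details contains finish_reason for backward compatibility."""
        if self.finish_reason:
            if self.vendor_details is None:
                self.vendor_details = {}
            self.vendor_details['finish_reason'] = self.finish_reason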


def otel_events(self, settings: InstrumentationSettings) -> list[Event]:
Collaborator:
The issue mentions:

These fields would be used to populate gen_ai.response.id and gen_ai.response.finish_reasons in opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans#genai-attributes

Can you please handle that here as well so this PR can close that issue?

Collaborator:
Note that gen_ai.response.finish_reasons has specific allowed values: #1882 (comment)

I also left some other related suggestions on that older PR that tried to add finish_reason -- can you please check those out as well?
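For reference, a rough sketch of how the new fields could feed those attributes (the attribute names come from the semconv page mentioned above; the helper name is made up for illustration, and where exactly it hooks into the instrumentation code is left open):

    # Illustrative only: set the response-level attributes on whatever span the
    # instrumentation layer owns for this request.
    from opentelemetry.trace import Span

    def _set_response_attributes(span: Span, response: ModelResponse) -> None:
        if response.id is not None:
            span.set_attribute('gen_ai.response.id', response.id)
        if response.finish_reason is not None:
            # gen_ai.response.finish_reasons is an array-valued attribute
            span.set_attribute('gen_ai.response.finish_reasons', [response.finish_reason])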

"""Return OpenTelemetry events for the response."""
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -282,7 +282,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
)
)

return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id)
return ModelResponse(items, usage=_map_usage(response), model_name=response.model, id=response.id)

async def _process_streamed_response(self, response: AsyncStream[BetaRawMessageStreamEvent]) -> StreamedResponse:
peekable_response = _utils.PeekableAsyncStream(response)
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -296,7 +296,7 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
total_tokens=response['usage']['totalTokens'],
)
vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None)
return ModelResponse(items, usage=u, model_name=self.model_name, vendor_id=vendor_id)
return ModelResponse(items, usage=u, model_name=self.model_name, id=vendor_id)

@overload
async def _messages_create(
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/function.py
@@ -137,6 +137,8 @@ async def request(
if not response.usage.has_values(): # pragma: no branch
response.usage = _estimate_usage(chain(messages, [response]))
response.usage.requests = 1
response.id = getattr(response, 'id', None)
response.finish_reason = getattr(response, 'finish_reason', None)
Collaborator:
Why are these lines necessary?

return response

@asynccontextmanager
11 changes: 5 additions & 6 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -273,15 +273,16 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
parts = response['candidates'][0]['content']['parts']
vendor_id = response.get('vendor_id', None)
finish_reason = response['candidates'][0].get('finish_reason')
vendor_details = {}
if finish_reason:
vendor_details = {'finish_reason': finish_reason}
vendor_details['finish_reason'] = finish_reason
Collaborator:
As in the other PR that worked on this (#1882 (comment)), I think we should keep the raw value here and then add finish_reason with a mapped value. But note that GeminiModel is deprecated and we should do this in GoogleModel instead.
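For illustration, the pattern being described might look roughly like this in GoogleModel's _process_response (the _map_finish_reason helper is the hypothetical one sketched earlier, not an existing function):

    # Illustrative only: keep the raw provider value in vendor_details and put
    # only a mapped value (or None) into the new finish_reason field.
    vendor_details: dict[str, Any] = {}
    if finish_reason:
        vendor_details['finish_reason'] = finish_reason  # raw provider value
    mapped_finish_reason = _map_finish_reason(finish_reason)  # hypothetical helper

This keeps the raw value available to anything that already reads vendor_details, while finish_reason itself only ever holds a value from the allowed set.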

usage = _metadata_as_usage(response)
usage.requests = 1
return _process_response_from_parts(
parts,
response.get('model_version', self._model_name),
usage,
vendor_id=vendor_id,
id=vendor_id,
vendor_details=vendor_details,
)

@@ -662,7 +663,7 @@ def _process_response_from_parts(
parts: Sequence[_GeminiPartUnion],
model_name: GeminiModelName,
usage: usage.Usage,
vendor_id: str | None,
id: str | None,
vendor_details: dict[str, Any] | None = None,
) -> ModelResponse:
items: list[ModelResponsePart] = []
@@ -680,9 +681,7 @@ def _process_response_from_parts(
raise UnexpectedModelBehavior(
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
)
return ModelResponse(
parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
)
return ModelResponse(parts=items, usage=usage, model_name=model_name, id=id, vendor_details=vendor_details)


class _GeminiFunctionCall(TypedDict):
8 changes: 3 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/google.py
@@ -322,7 +322,7 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
usage = _metadata_as_usage(response)
usage.requests = 1
return _process_response_from_parts(
parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
parts, response.model_version or self._model_name, usage, id=vendor_id, vendor_details=vendor_details
)

async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
@@ -506,7 +506,7 @@ def _process_response_from_parts(
parts: list[Part],
model_name: GoogleModelName,
usage: usage.Usage,
vendor_id: str | None,
id: str | None,
vendor_details: dict[str, Any] | None = None,
) -> ModelResponse:
items: list[ModelResponsePart] = []
@@ -526,9 +526,7 @@ def _process_response_from_parts(
raise UnexpectedModelBehavior(
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
)
return ModelResponse(
parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
)
return ModelResponse(parts=items, model_name=model_name, usage=usage, id=id, vendor_details=vendor_details)


def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -266,7 +266,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
for c in choice.message.tool_calls:
items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
return ModelResponse(
items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, id=response.id
)

async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/huggingface.py
@@ -253,7 +253,7 @@ def _process_response(self, response: ChatCompletionOutput) -> ModelResponse:
usage=_map_usage(response),
model_name=response.model,
timestamp=timestamp,
vendor_id=response.id,
id=response.id,
)

async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse:
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -341,7 +341,7 @@ def _process_response(self, response: MistralChatCompletionResponse) -> ModelRes
parts.append(tool)

return ModelResponse(
parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, id=response.id
)

async def _process_streamed_response(
5 changes: 3 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -419,7 +419,8 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons
model_name=response.model,
timestamp=timestamp,
vendor_details=vendor_details,
vendor_id=response.id,
id=response.id,
finish_reason=choice.finish_reason,
)

async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
@@ -706,7 +707,7 @@ def _process_response(self, response: responses.Response) -> ModelResponse:
items,
usage=_map_usage(response),
model_name=response.model,
vendor_id=response.id,
id=response.id,
timestamp=timestamp,
)

31 changes: 26 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/test.py
@@ -227,23 +227,44 @@ def _request(
output[part.tool_name] = part.content
if output:
return ModelResponse(
parts=[TextPart(pydantic_core.to_json(output).decode())], model_name=self._model_name
parts=[TextPart(pydantic_core.to_json(output).decode())],
model_name=self._model_name,
id=None,
finish_reason=None,
Collaborator:
As these fields are optional, we shouldn't need to include them here (and below)

)
else:
return ModelResponse(parts=[TextPart('success (no tool calls)')], model_name=self._model_name)
return ModelResponse(
parts=[TextPart('success (no tool calls)')],
model_name=self._model_name,
id=None,
finish_reason=None,
)
else:
return ModelResponse(parts=[TextPart(response_text)], model_name=self._model_name)
return ModelResponse(
parts=[TextPart(response_text)],
model_name=self._model_name,
id=None,
finish_reason=None,
)
else:
assert output_tools, 'No output tools provided'
custom_output_args = output_wrapper.value
output_tool = output_tools[self.seed % len(output_tools)]
if custom_output_args is not None:
return ModelResponse(
parts=[ToolCallPart(output_tool.name, custom_output_args)], model_name=self._model_name
parts=[ToolCallPart(output_tool.name, custom_output_args)],
model_name=self._model_name,
id=None,
finish_reason=None,
)
else:
response_args = self.gen_tool_args(output_tool)
return ModelResponse(parts=[ToolCallPart(output_tool.name, response_args)], model_name=self._model_name)
return ModelResponse(
parts=[ToolCallPart(output_tool.name, response_args)],
model_name=self._model_name,
id=None,
finish_reason=None,
)


@dataclass
Expand Down