Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_otel_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

from typing import Literal

from pydantic import JsonValue
from typing_extensions import NotRequired, TypeAlias, TypedDict


class TextPart(TypedDict):
type: Literal['text']
content: NotRequired[str]


class ToolCallPart(TypedDict):
type: Literal['tool_call']
id: str
name: str
arguments: NotRequired[JsonValue]


class ToolCallResponsePart(TypedDict):
type: Literal['tool_call_response']
id: str
name: str
result: NotRequired[JsonValue]


class MediaUrlPart(TypedDict):
type: Literal['image-url', 'audio-url', 'video-url', 'document-url']
url: NotRequired[str]


class BinaryDataPart(TypedDict):
type: Literal['binary']
media_type: str
content: NotRequired[str]


class ThinkingPart(TypedDict):
type: Literal['thinking']
content: NotRequired[str]


MessagePart: TypeAlias = 'TextPart | ToolCallPart | ToolCallResponsePart | MediaUrlPart | BinaryDataPart | ThinkingPart'


Role = Literal['system', 'user', 'assistant']


class ChatMessage(TypedDict):
role: Role
parts: list[MessagePart]


InputMessages: TypeAlias = list[ChatMessage]


class OutputMessage(ChatMessage):
finish_reason: NotRequired[str]


OutputMessages: TypeAlias = list[OutputMessage]
110 changes: 92 additions & 18 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import TypeAlias, deprecated

from . import _utils
from . import _otel_messages, _utils
from ._utils import (
generate_tool_call_id as _generate_tool_call_id,
now_utc as _now_utc,
Expand Down Expand Up @@ -83,6 +83,9 @@ def otel_event(self, settings: InstrumentationSettings) -> Event:
body={'role': 'system', **({'content': self.content} if settings.include_content else {})},
)

def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]:
return [_otel_messages.TextPart(type='text', **{'content': self.content} if settings.include_content else {})]

__repr__ = _utils.dataclasses_no_defaults_repr


Expand Down Expand Up @@ -505,25 +508,41 @@ class UserPromptPart:
"""Part type identifier, this is available on all parts as a discriminator."""

def otel_event(self, settings: InstrumentationSettings) -> Event:
content: str | list[dict[str, Any] | str] | dict[str, Any]
if isinstance(self.content, str):
content = self.content if settings.include_content else {'kind': 'text'}
else:
content = []
for part in self.content:
if isinstance(part, str):
content.append(part if settings.include_content else {'kind': 'text'})
elif isinstance(part, (ImageUrl, AudioUrl, DocumentUrl, VideoUrl)):
content.append({'kind': part.kind, **({'url': part.url} if settings.include_content else {})})
elif isinstance(part, BinaryContent):
converted_part = {'kind': part.kind, 'media_type': part.media_type}
if settings.include_content and settings.include_binary_content:
converted_part['binary_content'] = base64.b64encode(part.data).decode()
content.append(converted_part)
else:
content.append({'kind': part.kind}) # pragma: no cover
content = [{'kind': part.pop('type'), **part} for part in self.otel_message_parts(settings)]
for part in content:
if part['kind'] == 'binary' and 'content' in part:
part['binary_content'] = part.pop('content')
content = [
part['content'] if part == {'kind': 'text', 'content': part.get('content')} else part for part in content
]
if content in ([{'kind': 'text'}], [self.content]):
content = content[0]
return Event('gen_ai.user.message', body={'content': content, 'role': 'user'})

def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]:
parts: list[_otel_messages.MessagePart] = []
content: Sequence[UserContent] = [self.content] if isinstance(self.content, str) else self.content
for part in content:
if isinstance(part, str):
parts.append(
_otel_messages.TextPart(type='text', **({'content': part} if settings.include_content else {}))
)
elif isinstance(part, (ImageUrl, AudioUrl, DocumentUrl, VideoUrl)):
parts.append(
_otel_messages.MediaUrlPart(
type=part.kind,
**{'url': part.url} if settings.include_content else {},
)
)
elif isinstance(part, BinaryContent):
converted_part = _otel_messages.BinaryDataPart(type='binary', media_type=part.media_type)
if settings.include_content and settings.include_binary_content:
converted_part['content'] = base64.b64encode(part.data).decode()
parts.append(converted_part)
else:
parts.append({'type': part.kind}) # pragma: no cover
return parts

__repr__ = _utils.dataclasses_no_defaults_repr


Expand Down Expand Up @@ -577,6 +596,18 @@ def otel_event(self, settings: InstrumentationSettings) -> Event:
},
)

def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]:
from .models.instrumented import InstrumentedModel

return [
_otel_messages.ToolCallResponsePart(
type='tool_call_response',
id=self.tool_call_id,
name=self.tool_name,
**({'result': InstrumentedModel.serialize_any(self.content)} if settings.include_content else {}),
)
]

def has_content(self) -> bool:
"""Return `True` if the tool return has content."""
return self.content is not None # pragma: no cover
Expand Down Expand Up @@ -670,6 +701,19 @@ def otel_event(self, settings: InstrumentationSettings) -> Event:
},
)

def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]:
if self.tool_name is None:
return [_otel_messages.TextPart(type='text', content=self.model_response())]
else:
return [
_otel_messages.ToolCallResponsePart(
type='tool_call_response',
id=self.tool_call_id,
name=self.tool_name,
**({'result': self.model_response()} if settings.include_content else {}),
)
]

__repr__ = _utils.dataclasses_no_defaults_repr


Expand Down Expand Up @@ -911,6 +955,36 @@ def new_event_body():

return result

def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]:
parts: list[_otel_messages.MessagePart] = []
for part in self.parts:
if isinstance(part, TextPart):
parts.append(
_otel_messages.TextPart(
type='text',
**({'content': part.content} if settings.include_content else {}),
)
)
elif isinstance(part, ThinkingPart):
parts.append(
_otel_messages.ThinkingPart(
type='thinking',
**({'content': part.content} if settings.include_content else {}),
)
)
elif isinstance(part, ToolCallPart):
call_part = _otel_messages.ToolCallPart(type='tool_call', id=part.tool_call_id, name=part.tool_name)
if settings.include_content and part.args is not None:
from .models.instrumented import InstrumentedModel

if isinstance(part.args, str):
call_part['arguments'] = part.args
else:
call_part['arguments'] = {k: InstrumentedModel.serialize_any(v) for k, v in part.args.items()}

parts.append(call_part)
return parts

@property
@deprecated('`vendor_details` is deprecated, use `provider_details` instead')
def vendor_details(self) -> dict[str, Any] | None:
Expand Down
Loading