Update cohere and MCP, add support for MCP ResourceLink returned from tools #2094

Merged · 39 commits · Jul 24, 2025
Changes from 20 commits
c1dfca6
Update tests to be compatible with new OpenAI, MistralAI and MCP vers…
medaminezghal Jun 28, 2025
8d42e69
Fix mcp compatibility
medaminezghal Jun 28, 2025
290becb
add ResourceLink type
medaminezghal Jun 28, 2025
a60e11d
add ResourceLink type
medaminezghal Jun 28, 2025
7d98900
Merge branch 'main' into update-versins
medaminezghal Jul 5, 2025
749c71d
Fix ResourceLink MCP types
medaminezghal Jul 8, 2025
7b83106
Fix
medaminezghal Jul 9, 2025
2091142
Merge branch 'main' into update-versins
medaminezghal Jul 9, 2025
6766617
Add tests
medaminezghal Jul 10, 2025
8c32055
Merge branch 'main' into update-versins
medaminezghal Jul 10, 2025
7e1c180
Fix tests fails for new cohere version
medaminezghal Jul 11, 2025
5d3112f
Merge branch 'main' into update-versins
medaminezghal Jul 11, 2025
5d05916
Merge branch 'main' into update-versins
medaminezghal Jul 17, 2025
c23076d
Revert change to cohere and fix mcp tests naming
medaminezghal Jul 17, 2025
e0e8369
Some fixes
medaminezghal Jul 17, 2025
2289561
Merge branch 'main' into update-versins
medaminezghal Jul 17, 2025
8f4f718
Update product_name.txt
medaminezghal Jul 17, 2025
7a985a6
Update test_mcp.py
medaminezghal Jul 17, 2025
bd9856e
Add compatibility with opentelemtry-api>=1.35
medaminezghal Jul 17, 2025
0f0c9de
Merge branch 'main' into update-versins
medaminezghal Jul 17, 2025
4b76833
Add @mcp.resource to ResourceLink tests tools
medaminezghal Jul 22, 2025
2cf4fcb
Fix conflicts with main
medaminezghal Jul 22, 2025
27f1a0d
Merge branch 'main' into update-versins
medaminezghal Jul 22, 2025
30b9d0d
Fix opentelemetry compatibility with version<1.35
medaminezghal Jul 23, 2025
32d0300
Fix
medaminezghal Jul 23, 2025
0b4e81a
Fix mcp_server.py
medaminezghal Jul 23, 2025
becbc57
Merge branch 'main' into update-versins
medaminezghal Jul 23, 2025
38d3b3c
Fix uv.lock
medaminezghal Jul 23, 2025
e9ab203
Fix tests
medaminezghal Jul 23, 2025
21ac93d
Revert Otel Upgrade
medaminezghal Jul 23, 2025
ce21660
Revert Otel Upgrade
medaminezghal Jul 23, 2025
1aa5c3d
Fix lint
medaminezghal Jul 23, 2025
2feaddc
Update MCP ResourceLink tests and add VCR cassettes
DouweM Jul 23, 2025
4d44f48
Fix tests
medaminezghal Jul 24, 2025
935c711
Fix coverage
medaminezghal Jul 24, 2025
de11883
Fix coverage
medaminezghal Jul 24, 2025
4ba2c0d
Fix coverage
medaminezghal Jul 24, 2025
d4a9f51
Fix coverage
medaminezghal Jul 24, 2025
411ef24
Merge branch 'main' into update-versins
medaminezghal Jul 24, 2025
6 changes: 3 additions & 3 deletions docs/logfire.md
@@ -283,17 +283,17 @@ Note that the OpenTelemetry Semantic Conventions are still experimental and are
 
 ### Setting OpenTelemetry SDK providers
 
-By default, the global `TracerProvider` and `EventLoggerProvider` are used. These are set automatically by `logfire.configure()`. They can also be set by the `set_tracer_provider` and `set_event_logger_provider` functions in the OpenTelemetry Python SDK. You can set custom providers with [`InstrumentationSettings`][pydantic_ai.models.instrumented.InstrumentationSettings].
+By default, the global `TracerProvider` and `LoggerProvider` are used. These are set automatically by `logfire.configure()`. They can also be set by the `set_tracer_provider` and `set_event_logger_provider` functions in the OpenTelemetry Python SDK. You can set custom providers with [`InstrumentationSettings`][pydantic_ai.models.instrumented.InstrumentationSettings].
 
 ```python {title="instrumentation_settings_providers.py"}
-from opentelemetry.sdk._events import EventLoggerProvider
+from opentelemetry.sdk._logs import LoggerProvider
 from opentelemetry.sdk.trace import TracerProvider
 
 from pydantic_ai.agent import Agent, InstrumentationSettings
 
 instrumentation_settings = InstrumentationSettings(
     tracer_provider=TracerProvider(),
-    event_logger_provider=EventLoggerProvider(),
+    event_logger_provider=LoggerProvider(),
 )
 
 agent = Agent('gpt-4o', instrument=instrumentation_settings)
5 changes: 4 additions & 1 deletion docs/mcp/server.md
@@ -117,7 +117,10 @@ async def sampling_callback(
         SamplingMessage(
             role='user',
             content=TextContent(
-                type='text', text='write a poem about socks', annotations=None
+                type='text',
+                text='write a poem about socks',
+                annotations=None,
+                meta=None,
             ),
         )
     ]
32 changes: 21 additions & 11 deletions pydantic_ai_slim/pydantic_ai/mcp.py
@@ -148,7 +148,7 @@ async def direct_call_tool(
         except McpError as e:
             raise exceptions.ModelRetry(e.error.message)
 
-        content = [self._map_tool_result_part(part) for part in result.content]
+        content = [await self._map_tool_result_part(part) for part in result.content]
 
         if result.isError:
             text = '\n'.join(str(part) for part in content)
@@ -258,8 +258,8 @@ async def _sampling_callback(
             model=self.sampling_model.model_name,
         )
 
-    def _map_tool_result_part(
-        self, part: mcp_types.Content
+    async def _map_tool_result_part(
+        self, part: mcp_types.ContentBlock
     ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]:
         # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values
 
@@ -281,18 +281,28 @@ def _map_tool_result_part(
             )  # pragma: no cover
         elif isinstance(part, mcp_types.EmbeddedResource):
             resource = part.resource
-            if isinstance(resource, mcp_types.TextResourceContents):
-                return resource.text
-            elif isinstance(resource, mcp_types.BlobResourceContents):
-                return messages.BinaryContent(
-                    data=base64.b64decode(resource.blob),
-                    media_type=resource.mimeType or 'application/octet-stream',
-                )
-            else:
-                assert_never(resource)
+            return self._get_content(resource)
+        elif isinstance(part, mcp_types.ResourceLink):
+            resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri)
+            if len(resource_result.contents) > 1:
+                return [self._get_content(resource) for resource in resource_result.contents]
+            else:
+                return self._get_content(resource_result.contents[0])
         else:
             assert_never(part)
 
+    def _get_content(
+        self, resource: mcp_types.TextResourceContents | mcp_types.BlobResourceContents
+    ) -> str | messages.BinaryContent:
+        if isinstance(resource, mcp_types.TextResourceContents):
+            return resource.text
+        elif isinstance(resource, mcp_types.BlobResourceContents):
+            return messages.BinaryContent(
+                data=base64.b64decode(resource.blob), media_type=resource.mimeType or 'application/octet-stream'
+            )
+        else:
+            assert_never(resource)
 
 
 @dataclass
 class MCPServerStdio(MCPServer):
32 changes: 16 additions & 16 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -10,7 +10,7 @@
 
 import pydantic
 import pydantic_core
-from opentelemetry._events import Event  # pyright: ignore[reportPrivateImportUsage]
+from opentelemetry._logs import LogRecord  # pyright: ignore[reportPrivateImportUsage]
 from typing_extensions import TypeAlias, deprecated
 
 from . import _utils
@@ -76,9 +76,9 @@ class SystemPromptPart:
     part_kind: Literal['system-prompt'] = 'system-prompt'
     """Part type identifier, this is available on all parts as a discriminator."""
 
-    def otel_event(self, settings: InstrumentationSettings) -> Event:
-        return Event(
-            'gen_ai.system.message',
+    def otel_event(self, settings: InstrumentationSettings) -> LogRecord:
+        return LogRecord(
+            event_name='gen_ai.system.message',
             body={'role': 'system', **({'content': self.content} if settings.include_content else {})},
         )
 
@@ -410,7 +410,7 @@ class UserPromptPart:
     part_kind: Literal['user-prompt'] = 'user-prompt'
     """Part type identifier, this is available on all parts as a discriminator."""
 
-    def otel_event(self, settings: InstrumentationSettings) -> Event:
+    def otel_event(self, settings: InstrumentationSettings) -> LogRecord:
         content: str | list[dict[str, Any] | str] | dict[str, Any]
         if isinstance(self.content, str):
             content = self.content if settings.include_content else {'kind': 'text'}
@@ -428,7 +428,7 @@ def otel_event(self, settings: InstrumentationSettings) -> Event:
                 content.append(converted_part)
             else:
                 content.append({'kind': part.kind})  # pragma: no cover
-        return Event('gen_ai.user.message', body={'content': content, 'role': 'user'})
+        return LogRecord(event_name='gen_ai.user.message', body={'content': content, 'role': 'user'})
 
     __repr__ = _utils.dataclasses_no_defaults_repr
 
@@ -475,9 +475,9 @@ def model_response_object(self) -> dict[str, Any]:
         else:
             return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
 
-    def otel_event(self, settings: InstrumentationSettings) -> Event:
-        return Event(
-            'gen_ai.tool.message',
+    def otel_event(self, settings: InstrumentationSettings) -> LogRecord:
+        return LogRecord(
+            event_name='gen_ai.tool.message',
             body={
                 **({'content': self.content} if settings.include_content else {}),
                 'role': 'tool',
@@ -542,12 +542,12 @@ def model_response(self) -> str:
             description = f'{len(self.content)} validation errors: {json_errors.decode()}'
         return f'{description}\n\nFix the errors and try again.'
 
-    def otel_event(self, settings: InstrumentationSettings) -> Event:
+    def otel_event(self, settings: InstrumentationSettings) -> LogRecord:
         if self.tool_name is None:
-            return Event('gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
+            return LogRecord(event_name='gen_ai.user.message', body={'content': self.model_response(), 'role': 'user'})
         else:
-            return Event(
-                'gen_ai.tool.message',
+            return LogRecord(
+                event_name='gen_ai.tool.message',
                 body={
                     **({'content': self.model_response()} if settings.include_content else {}),
                     'role': 'tool',
@@ -726,13 +726,13 @@ class ModelResponse:
     vendor_id: str | None = None
    """Vendor ID as specified by the model provider. This can be used to track the specific request to the model."""
 
-    def otel_events(self, settings: InstrumentationSettings) -> list[Event]:
+    def otel_events(self, settings: InstrumentationSettings) -> list[LogRecord]:
         """Return OpenTelemetry events for the response."""
-        result: list[Event] = []
+        result: list[LogRecord] = []
 
         def new_event_body():
             new_body: dict[str, Any] = {'role': 'assistant'}
-            ev = Event('gen_ai.assistant.message', body=new_body)
+            ev = LogRecord(event_name='gen_ai.assistant.message', body=new_body)
             result.append(ev)
             return new_body
 
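The `messages.py` hunks above all apply the same mechanical migration: `opentelemetry._events.Event` becomes `opentelemetry._logs.LogRecord`, and the event name moves from the first positional argument to the `event_name` keyword. A minimal sketch of the new call shape, using a hypothetical stand-in for `LogRecord` so it runs without `opentelemetry-api` installed:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class LogRecord:  # hypothetical stand-in for opentelemetry._logs.LogRecord
    event_name: str
    body: dict[str, Any] = field(default_factory=dict)


def otel_event(content: str, include_content: bool = True) -> LogRecord:
    # Mirrors SystemPromptPart.otel_event after the change: the event name
    # is now passed as the event_name keyword rather than positionally.
    return LogRecord(
        event_name='gen_ai.system.message',
        body={'role': 'system', **({'content': content} if include_content else {})},
    )


rec = otel_event('You are a helpful assistant.')
print(rec.event_name, rec.body)
```

Nothing about the event bodies changes; only the constructor and the return annotations do, which is why the diff touches every `otel_event` method in the same way.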
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -40,7 +40,7 @@
     ChatMessageV2,
     ChatResponse,
     SystemChatMessageV2,
-    TextAssistantMessageContentItem,
+    TextAssistantMessageV2ContentItem,
     ToolCallV2,
     ToolCallV2Function,
     ToolChatMessageV2,
@@ -227,7 +227,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
                     assert_never(item)
             message_param = AssistantChatMessageV2(role='assistant')
             if texts:
-                message_param.content = [TextAssistantMessageContentItem(text='\n\n'.join(texts))]
+                message_param.content = [TextAssistantMessageV2ContentItem(text='\n\n'.join(texts))]
             if tool_calls:
                 message_param.tool_calls = tool_calls
             cohere_messages.append(message_param)
38 changes: 20 additions & 18 deletions pydantic_ai_slim/pydantic_ai/models/instrumented.py
@@ -7,11 +7,11 @@
 from typing import Any, Callable, Literal
 from urllib.parse import urlparse
 
-from opentelemetry._events import (
-    Event,  # pyright: ignore[reportPrivateImportUsage]
-    EventLogger,  # pyright: ignore[reportPrivateImportUsage]
-    EventLoggerProvider,  # pyright: ignore[reportPrivateImportUsage]
-    get_event_logger_provider,  # pyright: ignore[reportPrivateImportUsage]
+from opentelemetry._logs import (
+    Logger,  # pyright: ignore[reportPrivateImportUsage]
+    LoggerProvider,  # pyright: ignore[reportPrivateImportUsage]
+    LogRecord,  # pyright: ignore[reportPrivateImportUsage]
+    get_logger_provider,  # pyright: ignore[reportPrivateImportUsage]
 )
 from opentelemetry.metrics import MeterProvider, get_meter_provider
 from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
@@ -80,7 +80,7 @@ class InstrumentationSettings:
     """
 
     tracer: Tracer = field(repr=False)
-    event_logger: EventLogger = field(repr=False)
+    event_logger: Logger = field(repr=False)
     event_mode: Literal['attributes', 'logs'] = 'attributes'
     include_binary_content: bool = True
 
@@ -90,7 +90,7 @@ def __init__(
         event_mode: Literal['attributes', 'logs'] = 'attributes',
         tracer_provider: TracerProvider | None = None,
         meter_provider: MeterProvider | None = None,
-        event_logger_provider: EventLoggerProvider | None = None,
+        event_logger_provider: LoggerProvider | None = None,
         include_binary_content: bool = True,
         include_content: bool = True,
     ):
@@ -117,11 +117,11 @@ def __init__(
 
         tracer_provider = tracer_provider or get_tracer_provider()
         meter_provider = meter_provider or get_meter_provider()
-        event_logger_provider = event_logger_provider or get_event_logger_provider()
+        event_logger_provider = event_logger_provider or get_logger_provider()
         scope_name = 'pydantic-ai'
         self.tracer = tracer_provider.get_tracer(scope_name, __version__)
         self.meter = meter_provider.get_meter(scope_name, __version__)
-        self.event_logger = event_logger_provider.get_event_logger(scope_name, __version__)
+        self.event_logger = event_logger_provider.get_logger(scope_name, __version__)
         self.event_mode = event_mode
         self.include_binary_content = include_binary_content
         self.include_content = include_content
@@ -144,7 +144,7 @@ def __init__(
             **tokens_histogram_kwargs,  # pyright: ignore
         )
 
-    def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]:
+    def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[LogRecord]:
         """Convert a list of model messages to OpenTelemetry events.
 
         Args:
@@ -153,13 +153,15 @@ def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]:
         Returns:
             A list of OpenTelemetry events.
         """
-        events: list[Event] = []
+        events: list[LogRecord] = []
         instructions = InstrumentedModel._get_instructions(messages)  # pyright: ignore [reportPrivateUsage]
         if instructions is not None:
-            events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'}))
+            events.append(
+                LogRecord(event_name='gen_ai.system.message', body={'content': instructions, 'role': 'system'})
+            )
 
         for message_index, message in enumerate(messages):
-            message_events: list[Event] = []
+            message_events: list[LogRecord] = []
             if isinstance(message, ModelRequest):
                 for part in message.parts:
                     if hasattr(part, 'otel_event'):
@@ -297,8 +299,8 @@ def _record_metrics():
                 events = self.instrumentation_settings.messages_to_otel_events(messages)
                 for event in self.instrumentation_settings.messages_to_otel_events([response]):
                     events.append(
-                        Event(
-                            'gen_ai.choice',
+                        LogRecord(
+                            event_name='gen_ai.choice',
                             body={
                                 # TODO finish_reason
                                 'index': 0,
@@ -327,7 +329,7 @@ def _record_metrics():
         # to prevent them from being redundantly recorded in the span itself by logfire.
         record_metrics()
 
-    def _emit_events(self, span: Span, events: list[Event]) -> None:
+    def _emit_events(self, span: Span, events: list[LogRecord]) -> None:
         if self.instrumentation_settings.event_mode == 'logs':
             for event in events:
                 self.instrumentation_settings.event_logger.emit(event)
@@ -368,11 +370,11 @@ def model_attributes(model: Model):
         return attributes
 
     @staticmethod
-    def event_to_dict(event: Event) -> dict[str, Any]:
+    def event_to_dict(event: LogRecord) -> dict[str, Any]:
         if not event.body:
             body = {}  # pragma: no cover
         elif isinstance(event.body, Mapping):
-            body = event.body  # type: ignore
+            body = event.body
         else:
             body = {'body': event.body}
         return {**body, **(event.attributes or {})}
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pyproject.toml
@@ -54,7 +54,7 @@ dependencies = [
     "pydantic>=2.10",
     "pydantic-graph=={{ version }}",
     "exceptiongroup; python_version < '3.11'",
-    "opentelemetry-api>=1.28.0",
+    "opentelemetry-api>=1.35.0",
     "typing-inspection>=0.4.0",
 ]
 
@@ -63,7 +63,7 @@ dependencies = [
 logfire = ["logfire>=3.11.0"]
 # Models
 openai = ["openai>=1.92.0"]
-cohere = ["cohere>=5.13.11; platform_system != 'Emscripten'"]
+cohere = ["cohere>=5.16.0; platform_system != 'Emscripten'"]
 vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
 google = ["google-genai>=1.24.0"]
 anthropic = ["anthropic>=0.52.0"]
@@ -77,7 +77,7 @@ tavily = ["tavily-python>=0.5.0"]
 # CLI
 cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"]
 # MCP
-mcp = ["mcp>=1.9.4; python_version >= '3.10'"]
+mcp = ["mcp>=1.10.0; python_version >= '3.10'"]
 # Evals
 evals = ["pydantic-evals=={{ version }}"]
 # A2A
1 change: 1 addition & 0 deletions tests/assets/product_name.txt
@@ -0,0 +1 @@
+Pydantic AI