|
3 | 3 | from contextlib import suppress
|
4 | 4 | from copy import deepcopy
|
5 | 5 | from dataclasses import dataclass
|
| 6 | +from datetime import timedelta |
6 | 7 | from inspect import iscoroutinefunction
|
7 | 8 | from types import ModuleType, SimpleNamespace
|
8 | 9 | from typing import ClassVar, Generic, cast, overload
|
9 | 10 |
|
10 | 11 | from pydantic import BaseModel, Field
|
| 12 | +from typing_extensions import Self |
11 | 13 |
|
12 | 14 | from ragbits import agents
|
13 | 15 | from ragbits.agents.exceptions import (
|
|
20 | 22 | AgentToolNotAvailableError,
|
21 | 23 | AgentToolNotSupportedError,
|
22 | 24 | )
|
23 |
| -from ragbits.agents.mcp.server import MCPServer |
| 25 | +from ragbits.agents.mcp.server import MCPServer, MCPServerStdio, MCPServerStreamableHttp |
24 | 26 | from ragbits.agents.mcp.utils import get_tools
|
25 | 27 | from ragbits.agents.tool import Tool, ToolCallResult, ToolChoice
|
26 | 28 | from ragbits.core.audit.traces import trace
|
|
34 | 36 |
|
35 | 37 | with suppress(ImportError):
|
36 | 38 | from a2a.types import AgentCapabilities, AgentCard, AgentSkill
|
| 39 | + from pydantic_ai import Agent as PydanticAIAgent |
| 40 | + from pydantic_ai import mcp |
| 41 | + |
| 42 | + from ragbits.core.llms import LiteLLM |
37 | 43 |
|
38 | 44 |
|
39 | 45 | @dataclass
|
@@ -579,3 +585,106 @@ async def _extract_agent_skills(self) -> list["AgentSkill"]:
|
579 | 585 | )
|
580 | 586 | for tool in all_tools.values()
|
581 | 587 | ]
|
| 588 | + |
| 589 | + @requires_dependencies("pydantic_ai") |
| 590 | + def to_pydantic_ai(self) -> "PydanticAIAgent": |
| 591 | + """ |
| 592 | + Convert ragbits agent instance into a `pydantic_ai.Agent` representation. |
| 593 | +
|
| 594 | + Returns: |
| 595 | + PydanticAIAgent: The equivalent Pydantic-based agent configuration. |
| 596 | +
|
| 597 | + Raises: |
| 598 | + ValueError: If the `prompt` is not a string or a `Prompt` instance. |
| 599 | + """ |
| 600 | + mcp_servers: list[mcp.MCPServerStdio | mcp.MCPServerHTTP] = [] |
| 601 | + |
| 602 | + if not self.prompt: |
| 603 | + raise ValueError("Prompt is required but was None.") |
| 604 | + |
| 605 | + if isinstance(self.prompt, str): |
| 606 | + system_prompt = self.prompt |
| 607 | + else: |
| 608 | + if not self.prompt.system_prompt: |
| 609 | + raise ValueError("System prompt is required but was None.") |
| 610 | + system_prompt = self.prompt.system_prompt |
| 611 | + |
| 612 | + for mcp_server in self.mcp_servers: |
| 613 | + if isinstance(mcp_server, MCPServerStdio): |
| 614 | + mcp_servers.append( |
| 615 | + mcp.MCPServerStdio( |
| 616 | + command=mcp_server.params.command, args=mcp_server.params.args, env=mcp_server.params.env |
| 617 | + ) |
| 618 | + ) |
| 619 | + elif isinstance(mcp_server, MCPServerStreamableHttp): |
| 620 | + timeout = mcp_server.params["timeout"] |
| 621 | + sse_timeout = mcp_server.params["sse_read_timeout"] |
| 622 | + |
| 623 | + mcp_servers.append( |
| 624 | + mcp.MCPServerHTTP( |
| 625 | + url=mcp_server.params["url"], |
| 626 | + headers=mcp_server.params["headers"], |
| 627 | + timeout=timeout.total_seconds() if isinstance(timeout, timedelta) else timeout, |
| 628 | + sse_read_timeout=sse_timeout.total_seconds() |
| 629 | + if isinstance(sse_timeout, timedelta) |
| 630 | + else sse_timeout, |
| 631 | + ) |
| 632 | + ) |
| 633 | + return PydanticAIAgent( |
| 634 | + model=self.llm.model_name, |
| 635 | + system_prompt=system_prompt, |
| 636 | + tools=[tool.to_pydantic_ai() for tool in self.tools], |
| 637 | + mcp_servers=mcp_servers, |
| 638 | + ) |
| 639 | + |
| 640 | + @classmethod |
| 641 | + @requires_dependencies("pydantic_ai") |
| 642 | + def from_pydantic_ai(cls, pydantic_ai_agent: "PydanticAIAgent") -> Self: |
| 643 | + """ |
| 644 | + Construct an agent instance from a `pydantic_ai.Agent` representation. |
| 645 | +
|
| 646 | + Args: |
| 647 | + pydantic_ai_agent: A Pydantic-based agent configuration. |
| 648 | +
|
| 649 | + Returns: |
| 650 | + An instance of the agent class initialized from the Pydantic representation. |
| 651 | + """ |
| 652 | + mcp_servers: list[MCPServerStdio | MCPServerStreamableHttp] = [] |
| 653 | + for mcp_server in pydantic_ai_agent._mcp_servers: |
| 654 | + if isinstance(mcp_server, mcp.MCPServerStdio): |
| 655 | + mcp_servers.append( |
| 656 | + MCPServerStdio( |
| 657 | + params={ |
| 658 | + "command": mcp_server.command, |
| 659 | + "args": list(mcp_server.args), |
| 660 | + "env": mcp_server.env or {}, |
| 661 | + } |
| 662 | + ) |
| 663 | + ) |
| 664 | + elif isinstance(mcp_server, mcp.MCPServerHTTP): |
| 665 | + headers = mcp_server.headers or {} |
| 666 | + |
| 667 | + mcp_servers.append( |
| 668 | + MCPServerStreamableHttp( |
| 669 | + params={ |
| 670 | + "url": mcp_server.url, |
| 671 | + "headers": {str(k): str(v) for k, v in headers.items()}, |
| 672 | + "sse_read_timeout": mcp_server.sse_read_timeout, |
| 673 | + "timeout": mcp_server.timeout, |
| 674 | + } |
| 675 | + ) |
| 676 | + ) |
| 677 | + |
| 678 | + if not pydantic_ai_agent.model: |
| 679 | + raise ValueError("Missing LLM in `pydantic_ai.Agent` instance") |
| 680 | + elif isinstance(pydantic_ai_agent.model, str): |
| 681 | + model_name = pydantic_ai_agent.model |
| 682 | + else: |
| 683 | + model_name = pydantic_ai_agent.model.model_name |
| 684 | + |
| 685 | + return cls( |
| 686 | + llm=LiteLLM(model_name=model_name), # type: ignore[arg-type] |
| 687 | + prompt="\n".join(pydantic_ai_agent._system_prompts), |
| 688 | + tools=[tool.function for _, tool in pydantic_ai_agent._function_tools.items()], |
| 689 | + mcp_servers=cast(list[MCPServer], mcp_servers), |
| 690 | + ) |
0 commit comments