|
4 | 4 | from collections.abc import AsyncIterable, AsyncIterator, Generator, Iterator |
5 | 5 | from contextlib import contextmanager |
6 | 6 | from dataclasses import dataclass, field |
| 7 | +from datetime import datetime |
7 | 8 | from typing import Any |
8 | 9 |
|
9 | 10 | import pytest |
10 | 11 | from httpx import AsyncClient |
11 | 12 | from pydantic import BaseModel |
| 13 | +from typing_extensions import Literal |
12 | 14 |
|
13 | 15 | from pydantic_ai import Agent |
14 | 16 | from pydantic_ai._run_context import RunContext |
|
29 | 31 | ToolReturnPart, |
30 | 32 | ) |
31 | 33 | from pydantic_ai.models import cached_async_http_client |
| 34 | +from pydantic_ai.models.test import TestModel |
32 | 35 |
|
33 | 36 | from .conftest import IsDatetime, IsStr |
34 | 37 |
|
@@ -1077,3 +1080,82 @@ async def test_dbos_agent_with_unserializable_deps_type(allow_model_requests: No |
1077 | 1080 | async def test_logfire_plugin(): |
1078 | 1081 | # Not a valid test for DBOS, as we don't need the LogfirePlugin. |
1079 | 1082 | pass |
| 1083 | + |
| 1084 | + |
| 1085 | +# Test dynamic toolsets in an agent with DBOS |
| 1086 | + |
| 1087 | + |
| 1088 | +@DBOS.step() |
| 1089 | +def temperature_celsius(city: str) -> float: |
| 1090 | + return 21.0 |
| 1091 | + |
| 1092 | + |
| 1093 | +@DBOS.step() |
| 1094 | +def temperature_fahrenheit(city: str) -> float: |
| 1095 | + return 69.8 |
| 1096 | + |
| 1097 | + |
| 1098 | +weather_toolset = FunctionToolset(tools=[temperature_celsius, temperature_fahrenheit]) |
| 1099 | + |
| 1100 | + |
| 1101 | +@weather_toolset.tool |
| 1102 | +@DBOS.step() |
| 1103 | +def conditions(ctx: RunContext, city: str) -> str: |
| 1104 | + if ctx.run_step % 2 == 0: |
| 1105 | + return "It's sunny" # pragma: lax no cover |
| 1106 | + else: |
| 1107 | + return "It's raining" |
| 1108 | + |
| 1109 | + |
| 1110 | +datetime_toolset = FunctionToolset() |
| 1111 | + |
| 1112 | + |
| 1113 | +@DBOS.step() |
| 1114 | +def now_func() -> datetime: |
| 1115 | + return datetime.now() |
| 1116 | + |
| 1117 | + |
| 1118 | +datetime_toolset.add_function(now_func, name='now') |
| 1119 | + |
| 1120 | + |
| 1121 | +@dataclass |
| 1122 | +class ToggleableDeps: |
| 1123 | + active: Literal['weather', 'datetime'] |
| 1124 | + |
| 1125 | + def toggle(self): |
| 1126 | + if self.active == 'weather': |
| 1127 | + self.active = 'datetime' |
| 1128 | + else: |
| 1129 | + self.active = 'weather' |
| 1130 | + |
| 1131 | + |
| 1132 | +test_model = TestModel() |
| 1133 | +dynamic_agent = Agent(name='dynamic_agent', model=test_model, deps_type=ToggleableDeps) |
| 1134 | + |
| 1135 | + |
| 1136 | +@dynamic_agent.toolset # type: ignore |
| 1137 | +def toggleable_toolset(ctx: RunContext[ToggleableDeps]) -> FunctionToolset[None]: |
| 1138 | + if ctx.deps.active == 'weather': |
| 1139 | + return weather_toolset |
| 1140 | + else: |
| 1141 | + return datetime_toolset |
| 1142 | + |
| 1143 | + |
| 1144 | +@dynamic_agent.tool |
| 1145 | +def toggle(ctx: RunContext[ToggleableDeps]): |
| 1146 | + ctx.deps.toggle() |
| 1147 | + |
| 1148 | + |
| 1149 | +dynamic_dbos_agent = DBOSAgent(dynamic_agent) |
| 1150 | + |
| 1151 | + |
| 1152 | +def test_dynamic_toolset(dbos: DBOS): |
| 1153 | + weather_deps = ToggleableDeps('weather') |
| 1154 | + |
| 1155 | + result = dynamic_dbos_agent.run_sync('Toggle the toolset', deps=weather_deps) |
| 1156 | + assert result.output == snapshot( |
| 1157 | + '{"toggle":null,"temperature_celsius":21.0,"temperature_fahrenheit":69.8,"conditions":"It\'s raining"}' |
| 1158 | + ) |
| 1159 | + |
| 1160 | + result = dynamic_dbos_agent.run_sync('Toggle the toolset', deps=weather_deps) |
| 1161 | + assert result.output == snapshot(IsStr(regex=r'{"toggle":null,"now":".+?"}')) |
0 commit comments