|
82 | 82 | 'InstrumentationSettings',
|
83 | 83 | 'WrapperAgent',
|
84 | 84 | 'AbstractAgent',
|
| 85 | + 'EventStreamHandler', |
85 | 86 | )
|
86 | 87 |
|
87 | 88 |
|
@@ -401,6 +402,11 @@ def name(self, value: str | None) -> None:
|
401 | 402 | """Set the name of the agent, used for logging."""
|
402 | 403 | self._name = value
|
403 | 404 |
|
| 405 | + @property |
| 406 | + def deps_type(self) -> type: |
| 407 | + """The type of dependencies used by the agent.""" |
| 408 | + return self._deps_type |
| 409 | + |
404 | 410 | @property
|
405 | 411 | def output_type(self) -> OutputSpec[OutputDataT]:
|
406 | 412 | """The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`."""
|
@@ -593,12 +599,7 @@ async def main():
|
593 | 599 | run_step=state.run_step,
|
594 | 600 | )
|
595 | 601 |
|
596 |
| - toolset = self._get_toolset(additional=toolsets) |
597 |
| - |
598 |
| - if output_toolset is not None: |
599 |
| - if self._prepare_output_tools: |
600 |
| - output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) |
601 |
| - toolset = CombinedToolset([output_toolset, toolset]) |
| 602 | + toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) |
602 | 603 |
|
603 | 604 | async with toolset:
|
604 | 605 | # This will raise errors for any name conflicts
|
@@ -1240,48 +1241,64 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
|
1240 | 1241 | return deps
|
1241 | 1242 |
|
1242 | 1243 | def _get_toolset(
|
1243 |
| - self, additional: Sequence[AbstractToolset[AgentDepsT]] | None = None |
| 1244 | + self, |
| 1245 | + output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET, |
| 1246 | + additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, |
1244 | 1247 | ) -> AbstractToolset[AgentDepsT]:
|
1245 |
| - """Get the combined toolset containing function tools registered directly to the agent and user-provided toolsets including MCP servers. |
| 1248 | + """Get the complete toolset. |
1246 | 1249 |
|
1247 | 1250 | Args:
|
1248 |
| - additional: Additional toolsets to add. |
| 1251 | + output_toolset: The output toolset to use instead of the one built at agent construction time. |
| 1252 | + additional_toolsets: Additional toolsets to add, unless toolsets have been overridden. |
1249 | 1253 | """
|
1250 |
| - if some_tools := self._override_tools.get(): |
1251 |
| - function_toolset = _AgentFunctionToolset(some_tools.value, max_retries=self._max_tool_retries) |
1252 |
| - else: |
1253 |
| - function_toolset = self._function_toolset |
| 1254 | + toolsets = self.toolsets |
| 1255 | + # Don't add additional toolsets if the toolsets have been overridden |
| 1256 | + if additional_toolsets and self._override_toolsets.get() is None: |
| 1257 | + toolsets = [*toolsets, *additional_toolsets] |
1254 | 1258 |
|
1255 |
| - if some_user_toolsets := self._override_toolsets.get(): |
1256 |
| - user_toolsets = some_user_toolsets.value |
1257 |
| - else: |
1258 |
| - # Copy the dynamic toolsets to ensure each run has its own instances |
1259 |
| - dynamic_toolsets = [dataclasses.replace(toolset) for toolset in self._dynamic_toolsets] |
1260 |
| - user_toolsets = [*self._user_toolsets, *dynamic_toolsets, *(additional or [])] |
| 1259 | + toolset = CombinedToolset(toolsets) |
1261 | 1260 |
|
1262 |
| - if user_toolsets: |
1263 |
| - toolset = CombinedToolset([function_toolset, *user_toolsets]) |
1264 |
| - else: |
1265 |
| - toolset = function_toolset |
| 1261 | + # Copy the dynamic toolsets to ensure each run has its own instances |
| 1262 | + def copy_dynamic_toolsets(toolset: AbstractToolset[AgentDepsT]) -> AbstractToolset[AgentDepsT]: |
| 1263 | + if isinstance(toolset, DynamicToolset): |
| 1264 | + return dataclasses.replace(toolset) |
| 1265 | + else: |
| 1266 | + return toolset |
| 1267 | + |
| 1268 | + toolset = toolset.visit_and_replace(copy_dynamic_toolsets) |
1266 | 1269 |
|
1267 | 1270 | if self._prepare_tools:
|
1268 | 1271 | toolset = PreparedToolset(toolset, self._prepare_tools)
|
1269 | 1272 |
|
| 1273 | + output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset |
| 1274 | + if output_toolset is not None: |
| 1275 | + if self._prepare_output_tools: |
| 1276 | + output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) |
| 1277 | + toolset = CombinedToolset([output_toolset, toolset]) |
| 1278 | + |
1270 | 1279 | return toolset
|
1271 | 1280 |
|
1272 | 1281 | @property
|
1273 | 1282 | def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
|
1274 | 1283 | """All toolsets registered on the agent, including a function toolset holding tools that were registered on the agent directly.
|
1275 | 1284 |
|
1276 |
| - If a `prepare_tools` function was configured on the agent, this will contain just a `PreparedToolset` wrapping the original toolsets. |
1277 |
| -
|
1278 | 1285 | Output tools are not included.
|
1279 | 1286 | """
|
1280 |
| - toolset = self._get_toolset() |
1281 |
| - if isinstance(toolset, CombinedToolset): |
1282 |
| - return toolset.toolsets |
| 1287 | + toolsets: list[AbstractToolset[AgentDepsT]] = [] |
| 1288 | + |
| 1289 | + if some_tools := self._override_tools.get(): |
| 1290 | + function_toolset = _AgentFunctionToolset(some_tools.value, max_retries=self._max_tool_retries) |
| 1291 | + else: |
| 1292 | + function_toolset = self._function_toolset |
| 1293 | + toolsets.append(function_toolset) |
| 1294 | + |
| 1295 | + if some_user_toolsets := self._override_toolsets.get(): |
| 1296 | + user_toolsets = some_user_toolsets.value |
1283 | 1297 | else:
|
1284 |
| - return [toolset] |
| 1298 | + user_toolsets = [*self._user_toolsets, *self._dynamic_toolsets] |
| 1299 | + toolsets.extend(user_toolsets) |
| 1300 | + |
| 1301 | + return toolsets |
1285 | 1302 |
|
1286 | 1303 | def _prepare_output_schema(
|
1287 | 1304 | self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile
|
@@ -1369,7 +1386,7 @@ async def run_mcp_servers(
|
1369 | 1386 | class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
|
1370 | 1387 | @property
|
1371 | 1388 | def id(self) -> str:
|
1372 |
| - return '<agent>' # pragma: no cover |
| 1389 | + return '<agent>' |
1373 | 1390 |
|
1374 | 1391 | @property
|
1375 | 1392 | def label(self) -> str:
|
|
0 commit comments