-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Add spec parameter to agent.override() and agent.run()
#4769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: capabilities
Are you sure you want to change the base?
Changes from 4 commits
3a9a537
fa70cc6
94ee345
65680fb
1504a8a
443a217
1cda25d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -114,6 +114,19 @@ | |
| NoneType = type(None) | ||
|
|
||
|
|
||
| @dataclasses.dataclass | ||
| class _ResolvedSpec: | ||
| """Result of resolving an AgentSpec for use at run/override time.""" | ||
|
|
||
| spec: AgentSpec | ||
| capability: CombinedCapability[Any] | None | ||
| instructions: list[Any] | ||
| model: str | None | ||
| model_settings: ModelSettings | None | ||
| metadata: dict[str, Any] | None | ||
| name: str | None | ||
|
|
||
|
|
||
| @dataclasses.dataclass(init=False) | ||
| class Agent(AbstractAgent[AgentDepsT, OutputDataT]): | ||
| """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. | ||
|
|
@@ -472,6 +485,12 @@ def __init__( | |
| self._override_model_settings: ContextVar[_utils.Option[AgentModelSettings[AgentDepsT]]] = ContextVar( | ||
| '_override_model_settings', default=None | ||
| ) | ||
| self._override_root_capability: ContextVar[_utils.Option[CombinedCapability[AgentDepsT]]] = ContextVar( | ||
| '_override_root_capability', default=None | ||
| ) | ||
| self._override_builtin_tools: ContextVar[ | ||
| _utils.Option[list[AbstractBuiltinTool | _utils.Callable[..., Any]]] | ||
| ] = ContextVar('_override_builtin_tools', default=None) | ||
|
|
||
| self._enter_lock = Lock() | ||
| self._entered_count = 0 | ||
|
|
@@ -674,8 +693,14 @@ def _instantiate_cap( | |
| if capabilities: | ||
| all_capabilities.extend(capabilities) | ||
|
|
||
| effective_model = model or validated_spec.model | ||
| if effective_model is None: | ||
| raise exceptions.UserError( | ||
| '`model` must be provided either in the spec or as a keyword argument to `from_spec()`.' | ||
| ) | ||
|
|
||
| return Agent( | ||
| model=model or validated_spec.model, | ||
| model=effective_model, | ||
| output_type=effective_output_type, | ||
| instructions=merged_instructions or None, | ||
| system_prompt=system_prompt, | ||
|
|
@@ -881,6 +906,7 @@ def iter( | |
| infer_name: bool = True, | ||
| toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, | ||
| builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, | ||
| spec: dict[str, Any] | AgentSpec | None = None, | ||
| ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... | ||
|
|
||
| @overload | ||
|
|
@@ -901,6 +927,7 @@ def iter( | |
| infer_name: bool = True, | ||
| toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, | ||
| builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, | ||
| spec: dict[str, Any] | AgentSpec | None = None, | ||
| ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... | ||
|
|
||
| @asynccontextmanager | ||
|
|
@@ -921,6 +948,7 @@ async def iter( # noqa: C901 | |
| infer_name: bool = True, | ||
| toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, | ||
| builtin_tools: Sequence[AbstractBuiltinTool | BuiltinToolFunc[AgentDepsT]] | None = None, | ||
| spec: dict[str, Any] | AgentSpec | None = None, | ||
| ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: | ||
| """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. | ||
|
|
||
|
|
@@ -995,13 +1023,49 @@ async def main(): | |
| infer_name: Whether to try to infer the agent name from the call frame if it's not set. | ||
| toolsets: Optional additional toolsets for this run. | ||
| builtin_tools: Optional additional builtin tools for this run. | ||
| spec: Optional agent spec to apply for this run. At run time, spec values are additive. | ||
|
|
||
| Returns: | ||
| The result of the run. | ||
| """ | ||
| if infer_name and self.name is None: | ||
| self._infer_name(inspect.currentframe()) | ||
|
|
||
| # Resolve spec contributions (additive at run time) | ||
| resolved = self._resolve_spec(spec) | ||
| if resolved is not None: | ||
| # Model: spec as fallback (run param > spec > agent) | ||
| if model is None and resolved.model is not None: | ||
| model = resolved.model | ||
| # Instructions: spec instructions are additional | ||
| if resolved.instructions: | ||
| extra = resolved.instructions | ||
| if instructions is not None: | ||
| existing = _instructions.normalize_instructions(instructions) | ||
| existing.extend(extra) | ||
| instructions = existing | ||
| else: | ||
| instructions = extra | ||
| # Model settings: merge spec settings under run settings (only static dicts) | ||
| if resolved.model_settings is not None: | ||
| if model_settings is None or not callable(model_settings): | ||
| model_settings = merge_model_settings(resolved.model_settings, model_settings) | ||
| # If model_settings is a callable, spec model_settings are handled via the capability layer | ||
| # Metadata: merge spec metadata under run metadata | ||
| if resolved.metadata is not None: | ||
| if metadata is not None: | ||
| if callable(metadata): | ||
| _spec_meta = resolved.metadata | ||
|
|
||
| def _merged_meta(ctx: RunContext[AgentDepsT]) -> dict[str, Any]: | ||
| return {**(_spec_meta or {}), **metadata(ctx)} # type: ignore[operator] | ||
|
|
||
| metadata = _merged_meta | ||
devin-ai-integration[bot] marked this conversation as resolved.
Show resolved
Hide resolved
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Devin correctly identified this: _orig_metadata = metadata
def _merged_meta(ctx: RunContext[AgentDepsT]) -> dict[str, Any]:
return {**(_spec_meta or {}), **_orig_metadata(ctx)} # type: ignore[operator]
metadata = _merged_meta |
||
| else: | ||
| metadata = {**resolved.metadata, **metadata} | ||
| else: | ||
| metadata = resolved.metadata | ||
|
|
||
| model_used = self._get_model(model) | ||
| del model | ||
|
|
||
|
|
@@ -1061,15 +1125,32 @@ async def main(): | |
| run_step=0, | ||
| ) | ||
|
|
||
| # Determine root capability: override > agent default | ||
| override_cap = self._override_root_capability.get() | ||
| base_capability = override_cap.value if override_cap is not None else self._root_capability | ||
|
|
||
| # Merge spec capability additively with base capability | ||
| if resolved is not None and resolved.capability is not None: | ||
| effective_capability = CombinedCapability([base_capability, resolved.capability]) | ||
| else: | ||
| effective_capability = base_capability | ||
|
|
||
| # Per-run capability: re-extract get_*() if for_run returns a different instance | ||
| run_capability = await self._root_capability.for_run(initial_ctx) | ||
| run_capability = await effective_capability.for_run(initial_ctx) | ||
| cap_toolsets: list[AgentToolset[AgentDepsT]] | None | ||
| if run_capability is not self._root_capability: | ||
| if run_capability is not effective_capability: | ||
| cap_instructions = _instructions.normalize_instructions(run_capability.get_instructions()) | ||
| cap_builtin_tools = list(run_capability.get_builtin_tools()) | ||
| cap_model_settings = run_capability.get_model_settings() | ||
| cap_ts = run_capability.get_toolset() | ||
| cap_toolsets = [cap_ts] if cap_ts is not None else [] | ||
| elif override_cap is not None or (resolved is not None and resolved.capability is not None): | ||
| # Re-extract from effective_capability since it differs from self._root_capability | ||
| cap_instructions = _instructions.normalize_instructions(effective_capability.get_instructions()) | ||
| cap_builtin_tools = list(effective_capability.get_builtin_tools()) | ||
| cap_model_settings = effective_capability.get_model_settings() | ||
| cap_ts = effective_capability.get_toolset() | ||
| cap_toolsets = [cap_ts] if cap_ts is not None else [] | ||
| else: | ||
| cap_instructions = None # use init-time defaults | ||
| cap_builtin_tools = self._cap_builtin_tools | ||
|
|
@@ -1144,7 +1225,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: | |
| output_validators=output_validators, | ||
| validation_context=self._validation_context, | ||
| root_capability=run_capability, | ||
| builtin_tools=[*self._builtin_tools, *cap_builtin_tools, *(builtin_tools or [])], | ||
| builtin_tools=[ | ||
| *self._builtin_tools, | ||
| *cap_builtin_tools, | ||
| *(override_bt.value if (override_bt := self._override_builtin_tools.get()) is not None else []), | ||
|
||
| *(builtin_tools or []), | ||
| ], | ||
devin-ai-integration[bot] marked this conversation as resolved.
Show resolved
Hide resolved
devin-ai-integration[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| tool_manager=tool_manager, | ||
| tracer=tracer, | ||
| get_instructions=get_instructions, | ||
|
|
@@ -1410,6 +1496,92 @@ def _run_span_end_attributes( | |
| ), | ||
| } | ||
|
|
||
| def _resolve_spec( | ||
| self, | ||
| spec: dict[str, Any] | AgentSpec | None, | ||
| custom_capability_types: Sequence[type[AbstractCapability[Any]]] = (), | ||
| ) -> _ResolvedSpec | None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There's significant code duplication between |
||
| """Validate and instantiate capabilities from a spec, returning contributions. | ||
|
|
||
| Returns None if spec is None. | ||
| """ | ||
| if spec is None: | ||
| return None | ||
|
|
||
| from pydantic_ai._spec import build_registry, load_from_registry | ||
| from pydantic_ai._template import validate_from_spec_args | ||
| from pydantic_ai.agent.spec import AgentSpec as _AgentSpecModel | ||
| from pydantic_ai.capabilities import DEFAULT_CAPABILITY_TYPES | ||
|
|
||
| template_context: dict[str, Any] = { | ||
| 'deps_type': self._deps_type if self._deps_type is not type(None) else None, | ||
| } | ||
| if isinstance(spec, dict): | ||
| validated_spec = _AgentSpecModel.model_validate(spec, context=template_context) | ||
| else: | ||
| validated_spec = spec | ||
| template_context['deps_schema'] = validated_spec.deps_schema | ||
|
|
||
| registry = build_registry( | ||
| custom_types=custom_capability_types, | ||
| defaults=DEFAULT_CAPABILITY_TYPES, | ||
| get_name=lambda c: c.get_serialization_name(), | ||
| label='capability', | ||
| ) | ||
|
|
||
| def _instantiate_cap( | ||
| cap_cls: type[AbstractCapability[Any]], | ||
| args: tuple[Any, ...], | ||
| kwargs: dict[str, Any], | ||
| ) -> AbstractCapability[Any]: | ||
| args, kwargs = validate_from_spec_args(cap_cls, args, kwargs, template_context) | ||
| return cap_cls.from_spec(*args, **kwargs) | ||
|
|
||
| capabilities: list[AbstractCapability[Any]] = [] | ||
| for cap_spec in validated_spec.capabilities: | ||
| capability = load_from_registry( | ||
| registry, | ||
| cap_spec, | ||
| label='capability', | ||
| custom_types_param='custom_capability_types', | ||
| instantiate=_instantiate_cap, | ||
| ) | ||
| capabilities.append(capability) | ||
|
|
||
| combined = CombinedCapability(capabilities) if capabilities else None | ||
|
|
||
| # Warn for unsupported fields with non-default values | ||
| _unsupported_fields = { | ||
| 'end_strategy': 'early', | ||
| 'retries': 1, | ||
| 'output_retries': None, | ||
| 'tool_timeout': None, | ||
| 'output_schema': None, | ||
| 'deps_schema': None, | ||
| } | ||
|
Comment on lines
+1496
to
+1503
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🚩 Spec fields The Was this helpful? React with 👍 or 👎 to provide feedback. |
||
| for field_name, default_val in _unsupported_fields.items(): | ||
| val = getattr(validated_spec, field_name, default_val) | ||
| if val != default_val: | ||
| warnings.warn( | ||
| f'AgentSpec field {field_name!r} is not supported at run/override time and will be ignored', | ||
| UserWarning, | ||
| stacklevel=3, | ||
| ) | ||
|
|
||
| return _ResolvedSpec( | ||
| spec=validated_spec, | ||
| capability=combined, | ||
| instructions=_instructions.normalize_instructions(validated_spec.instructions) | ||
| if validated_spec.instructions | ||
| else [], | ||
| model=validated_spec.model, | ||
| model_settings=cast(ModelSettings, validated_spec.model_settings) | ||
| if validated_spec.model_settings | ||
| else None, | ||
| metadata=validated_spec.metadata, | ||
| name=validated_spec.name, | ||
| ) | ||
|
|
||
| @contextmanager | ||
| def override( # noqa: C901 | ||
| self, | ||
|
|
@@ -1422,6 +1594,7 @@ def override( # noqa: C901 | |
| instructions: _instructions.AgentInstructions[AgentDepsT] | _utils.Unset = _utils.UNSET, | ||
| metadata: AgentMetadata[AgentDepsT] | _utils.Unset = _utils.UNSET, | ||
| model_settings: AgentModelSettings[AgentDepsT] | _utils.Unset = _utils.UNSET, | ||
| spec: dict[str, Any] | AgentSpec | None = None, | ||
| ) -> Iterator[None]: | ||
| """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. | ||
|
|
||
|
|
@@ -1439,7 +1612,23 @@ def override( # noqa: C901 | |
| per-run `metadata` argument is ignored. | ||
| model_settings: The model settings to use instead of the model settings passed to the agent constructor. | ||
| When set, any per-run `model_settings` argument is ignored. | ||
| spec: Optional agent spec providing defaults for override. Explicit params take precedence over spec values. | ||
| """ | ||
| resolved = self._resolve_spec(spec) | ||
|
|
||
| # Apply spec values as defaults where explicit params are not set | ||
| if resolved is not None: | ||
| if not _utils.is_set(name) and resolved.name is not None: | ||
| name = resolved.name | ||
| if not _utils.is_set(model) and resolved.model is not None: | ||
| model = resolved.model | ||
| if not _utils.is_set(instructions) and resolved.instructions: | ||
| instructions = resolved.instructions | ||
| if not _utils.is_set(model_settings) and resolved.model_settings is not None: | ||
| model_settings = resolved.model_settings | ||
| if not _utils.is_set(metadata) and resolved.metadata is not None: | ||
| metadata = resolved.metadata | ||
|
Comment on lines
+1561
to
+1572
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🚩 Semantic difference between There's a notable design asymmetry in how
This is a meaningful behavioral difference that could confuse users. The same asymmetry applies to model settings and metadata. This may be intentional (override = replace, run = extend), but it's worth documenting explicitly. Was this helpful? React with 👍 or 👎 to provide feedback. |
||
|
|
||
| if _utils.is_set(name): | ||
| name_token = self._override_name.set(_utils.Some(name)) | ||
| else: | ||
|
|
@@ -1481,6 +1670,15 @@ def override( # noqa: C901 | |
| else: | ||
| model_settings_token = None | ||
|
|
||
| # Set capability and builtin_tools from spec | ||
| if resolved is not None and resolved.capability is not None: | ||
| cap_token = self._override_root_capability.set(_utils.Some(resolved.capability)) | ||
devin-ai-integration[bot] marked this conversation as resolved.
Show resolved
Hide resolved
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When By contrast, the @DouweM is the replacement behavior intentional for |
||
| builtin_tools_from_cap = list(resolved.capability.get_builtin_tools()) | ||
| bt_token = self._override_builtin_tools.set(_utils.Some(builtin_tools_from_cap)) | ||
devin-ai-integration[bot] marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| else: | ||
| cap_token = None | ||
| bt_token = None | ||
|
|
||
| try: | ||
| yield | ||
| finally: | ||
|
|
@@ -1500,6 +1698,10 @@ def override( # noqa: C901 | |
| self._override_metadata.reset(metadata_token) | ||
| if model_settings_token is not None: | ||
| self._override_model_settings.reset(model_settings_token) | ||
| if cap_token is not None: | ||
| self._override_root_capability.reset(cap_token) | ||
| if bt_token is not None: | ||
| self._override_builtin_tools.reset(bt_token) | ||
|
|
||
| @overload | ||
| def instructions( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
instructions: list[Any]is too loose — the type should match what_instructions.normalize_instructions()returns. Per the coding guidelines, avoidAnytype annotations; use the actual type for precision.