Skip to content

Commit 5380352

Browse files
Jacksunweicopybara-github
authored andcommitted
refactor(config): Makes BaseAgent.from_config a final method and let sub-class to optionally override _parse_config to update kwargs if needed
This ensures that the pydantic hooks (e.g. model_validators) are triggered correctly. PiperOrigin-RevId: 791545704
1 parent e3c2bf3 commit 5380352

File tree

5 files changed

+72
-58
lines changed

5 files changed

+72
-58
lines changed

src/google/adk/agents/base_agent.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -504,20 +504,17 @@ def __set_parent_agent_for_sub_agents(self) -> BaseAgent:
504504
sub_agent.parent_agent = self
505505
return self
506506

507+
@final
507508
@classmethod
508-
@working_in_progress('BaseAgent.from_config is not ready for use.')
509509
def from_config(
510510
cls: Type[SelfAgent],
511511
config: BaseAgentConfig,
512512
config_abs_path: str,
513513
) -> SelfAgent:
514514
"""Creates an agent from a config.
515515
516-
This method converts fields in a config to the corresponding
517-
fields in an agent.
518-
519-
Child classes should re-implement this method to support loading from their
520-
custom config types.
516+
If sub-classes uses a custom agent config, override `_from_config_kwargs`
517+
method to return an updated kwargs for agent construstor.
521518
522519
Args:
523520
config: The config to create the agent from.
@@ -527,6 +524,40 @@ def from_config(
527524
Returns:
528525
The created agent.
529526
"""
527+
kwargs = cls.__create_kwargs(config, config_abs_path)
528+
kwargs = cls._parse_config(config, config_abs_path, kwargs)
529+
return cls(**kwargs)
530+
531+
@classmethod
532+
def _parse_config(
533+
cls: Type[SelfAgent],
534+
config: BaseAgentConfig,
535+
config_abs_path: str,
536+
kwargs: Dict[str, Any],
537+
) -> Dict[str, Any]:
538+
"""Parses the config and returns updated kwargs to construct the agent.
539+
540+
Sub-classes should override this method to use a custome agent config class.
541+
542+
Args:
543+
config: The config to parse.
544+
config_abs_path: The absolute path to the config file that contains the
545+
agent config.
546+
kwargs: The keyword arguments used for agent constructor.
547+
548+
Returns:
549+
The updated keyword arguments used for agent constructor.
550+
"""
551+
return kwargs
552+
553+
@classmethod
554+
def __create_kwargs(
555+
cls,
556+
config: BaseAgentConfig,
557+
config_abs_path: str,
558+
) -> Dict[str, Any]:
559+
"""Creates kwargs for the fields of BaseAgent."""
560+
530561
from .config_agent_utils import resolve_agent_reference
531562
from .config_agent_utils import resolve_callbacks
532563

@@ -549,4 +580,4 @@ def from_config(
549580
kwargs['after_agent_callback'] = resolve_callbacks(
550581
config.after_agent_callbacks
551582
)
552-
return cls(**kwargs)
583+
return kwargs

src/google/adk/agents/llm_agent.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717
import importlib
1818
import inspect
1919
import logging
20-
import os
2120
from typing import Any
2221
from typing import AsyncGenerator
2322
from typing import Awaitable
2423
from typing import Callable
2524
from typing import ClassVar
25+
from typing import Dict
2626
from typing import Literal
2727
from typing import Optional
2828
from typing import Type
@@ -46,7 +46,6 @@
4646
from ..models.llm_response import LlmResponse
4747
from ..models.registry import LLMRegistry
4848
from ..planners.base_planner import BasePlanner
49-
from ..tools.agent_tool import AgentTool
5049
from ..tools.base_tool import BaseTool
5150
from ..tools.base_tool import ToolConfig
5251
from ..tools.base_toolset import BaseToolset
@@ -56,7 +55,6 @@
5655
from .base_agent import BaseAgent
5756
from .base_agent_config import BaseAgentConfig
5857
from .callback_context import CallbackContext
59-
from .common_configs import CodeConfig
6058
from .invocation_context import InvocationContext
6159
from .llm_agent_config import LlmAgentConfig
6260
from .readonly_context import ReadonlyContext
@@ -586,53 +584,55 @@ def _resolve_tools(
586584

587585
return resolved_tools
588586

589-
@classmethod
590587
@override
591-
@working_in_progress('LlmAgent.from_config is not ready for use.')
592-
def from_config(
588+
@classmethod
589+
def _parse_config(
593590
cls: Type[LlmAgent],
594591
config: LlmAgentConfig,
595592
config_abs_path: str,
596-
) -> LlmAgent:
593+
kwargs: Dict[str, Any],
594+
) -> Dict[str, Any]:
597595
from .config_agent_utils import resolve_callbacks
598596
from .config_agent_utils import resolve_code_reference
599597

600-
agent = super().from_config(config, config_abs_path)
601598
if config.model:
602-
agent.model = config.model
599+
kwargs['model'] = config.model
603600
if config.instruction:
604-
agent.instruction = config.instruction
601+
kwargs['instruction'] = config.instruction
605602
if config.disallow_transfer_to_parent:
606-
agent.disallow_transfer_to_parent = config.disallow_transfer_to_parent
603+
kwargs['disallow_transfer_to_parent'] = config.disallow_transfer_to_parent
607604
if config.disallow_transfer_to_peers:
608-
agent.disallow_transfer_to_peers = config.disallow_transfer_to_peers
605+
kwargs['disallow_transfer_to_peers'] = config.disallow_transfer_to_peers
609606
if config.include_contents != 'default':
610-
agent.include_contents = config.include_contents
607+
kwargs['include_contents'] = config.include_contents
611608
if config.input_schema:
612-
agent.input_schema = resolve_code_reference(config.input_schema)
609+
kwargs['input_schema'] = resolve_code_reference(config.input_schema)
613610
if config.output_schema:
614-
agent.output_schema = resolve_code_reference(config.output_schema)
611+
kwargs['output_schema'] = resolve_code_reference(config.output_schema)
615612
if config.output_key:
616-
agent.output_key = config.output_key
613+
kwargs['output_key'] = config.output_key
617614
if config.tools:
618-
agent.tools = cls._resolve_tools(config.tools, config_abs_path)
615+
kwargs['tools'] = cls._resolve_tools(config.tools, config_abs_path)
619616
if config.before_model_callbacks:
620-
agent.before_model_callback = resolve_callbacks(
617+
kwargs['before_model_callback'] = resolve_callbacks(
621618
config.before_model_callbacks
622619
)
623620
if config.after_model_callbacks:
624-
agent.after_model_callback = resolve_callbacks(
621+
kwargs['after_model_callback'] = resolve_callbacks(
625622
config.after_model_callbacks
626623
)
627624
if config.before_tool_callbacks:
628-
agent.before_tool_callback = resolve_callbacks(
625+
kwargs['before_tool_callback'] = resolve_callbacks(
629626
config.before_tool_callbacks
630627
)
631628
if config.after_tool_callbacks:
632-
agent.after_tool_callback = resolve_callbacks(config.after_tool_callbacks)
629+
kwargs['after_tool_callback'] = resolve_callbacks(
630+
config.after_tool_callbacks
631+
)
633632
if config.generate_content_config:
634-
agent.generate_content_config = config.generate_content_config
635-
return agent
633+
kwargs['generate_content_config'] = config.generate_content_config
634+
635+
return kwargs
636636

637637

638638
Agent: TypeAlias = LlmAgent

src/google/adk/agents/loop_agent.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
from __future__ import annotations
1818

19+
from typing import Any
1920
from typing import AsyncGenerator
2021
from typing import ClassVar
22+
from typing import Dict
2123
from typing import Optional
2224
from typing import Type
2325

@@ -74,15 +76,14 @@ async def _run_live_impl(
7476
raise NotImplementedError('This is not supported yet for LoopAgent.')
7577
yield # AsyncGenerator requires having at least one yield statement
7678

77-
@classmethod
7879
@override
79-
@working_in_progress('LoopAgent.from_config is not ready for use.')
80-
def from_config(
81-
cls: Type[LoopAgent],
80+
@classmethod
81+
def _parse_config(
82+
cls: type[LoopAgent],
8283
config: LoopAgentConfig,
8384
config_abs_path: str,
84-
) -> LoopAgent:
85-
agent = super().from_config(config, config_abs_path)
85+
kwargs: Dict[str, Any],
86+
) -> Dict[str, Any]:
8687
if config.max_iterations:
87-
agent.max_iterations = config.max_iterations
88-
return agent
88+
kwargs['max_iterations'] = config.max_iterations
89+
return kwargs

src/google/adk/agents/parallel_agent.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
from __future__ import annotations
1818

1919
import asyncio
20+
from typing import Any
2021
from typing import AsyncGenerator
2122
from typing import ClassVar
23+
from typing import Dict
2224
from typing import Type
2325

2426
from typing_extensions import override
@@ -119,13 +121,3 @@ async def _run_live_impl(
119121
) -> AsyncGenerator[Event, None]:
120122
raise NotImplementedError('This is not supported yet for ParallelAgent.')
121123
yield # AsyncGenerator requires having at least one yield statement
122-
123-
@classmethod
124-
@override
125-
@working_in_progress('ParallelAgent.from_config is not ready for use.')
126-
def from_config(
127-
cls: Type[ParallelAgent],
128-
config: ParallelAgentConfig,
129-
config_abs_path: str,
130-
) -> ParallelAgent:
131-
return super().from_config(config, config_abs_path)

src/google/adk/agents/sequential_agent.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,3 @@ def task_completed():
8181
for sub_agent in self.sub_agents:
8282
async for event in sub_agent.run_live(ctx):
8383
yield event
84-
85-
@classmethod
86-
@override
87-
@working_in_progress('SequentialAgent.from_config is not ready for use.')
88-
def from_config(
89-
cls: Type[SequentialAgent],
90-
config: SequentialAgentConfig,
91-
config_abs_path: str,
92-
) -> SequentialAgent:
93-
return super().from_config(config, config_abs_path)

0 commit comments

Comments
 (0)