Skip to content

Commit 9213c29

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
chore: Support specifying agent_framework in Agent Engine creation and update.
PiperOrigin-RevId: 825734997
1 parent 4216790 commit 9213c29

File tree

4 files changed

+208
-9
lines changed

4 files changed

+208
-9
lines changed

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
import pytest
4040

4141

42-
_TEST_AGENT_FRAMEWORK = "test-agent-framework"
42+
_TEST_AGENT_FRAMEWORK = "google-adk"
4343

4444

4545
class CapitalizeEngine:
@@ -929,9 +929,11 @@ def test_create_agent_engine_config_with_source_packages(
929929
entrypoint_object="app",
930930
requirements_file=requirements_file_path,
931931
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
932+
agent_framework=_TEST_AGENT_FRAMEWORK,
932933
)
933934
assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME
934935
assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION
936+
assert config["spec"]["agent_framework"] == _TEST_AGENT_FRAMEWORK
935937
assert config["spec"]["source_code_spec"] == {
936938
"inline_source": {"source_archive": "test_tarball"},
937939
"python_spec": {
@@ -1453,6 +1455,7 @@ def test_create_agent_engine_with_env_vars_dict(
14531455
entrypoint_module=None,
14541456
entrypoint_object=None,
14551457
requirements_file=None,
1458+
agent_framework=None,
14561459
)
14571460
request_mock.assert_called_with(
14581461
"post",
@@ -1539,6 +1542,7 @@ def test_create_agent_engine_with_custom_service_account(
15391542
entrypoint_module=None,
15401543
entrypoint_object=None,
15411544
requirements_file=None,
1545+
agent_framework=None,
15421546
)
15431547
request_mock.assert_called_with(
15441548
"post",
@@ -1627,6 +1631,7 @@ def test_create_agent_engine_with_experimental_mode(
16271631
entrypoint_module=None,
16281632
entrypoint_object=None,
16291633
requirements_file=None,
1634+
agent_framework=None,
16301635
)
16311636
request_mock.assert_called_with(
16321637
"post",
@@ -1779,6 +1784,7 @@ def test_create_agent_engine_with_class_methods(
17791784
entrypoint_module=None,
17801785
entrypoint_object=None,
17811786
requirements_file=None,
1787+
agent_framework=None,
17821788
)
17831789
request_mock.assert_called_with(
17841790
"post",
@@ -1798,6 +1804,92 @@ def test_create_agent_engine_with_class_methods(
17981804
None,
17991805
)
18001806

1807+
@mock.patch.object(agent_engines.AgentEngines, "_create_config")
1808+
@mock.patch.object(_agent_engines_utils, "_await_operation")
1809+
def test_create_agent_engine_with_agent_framework(
1810+
self,
1811+
mock_await_operation,
1812+
mock_create_config,
1813+
):
1814+
mock_create_config.return_value = {
1815+
"display_name": _TEST_AGENT_ENGINE_DISPLAY_NAME,
1816+
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
1817+
"spec": {
1818+
"package_spec": {
1819+
"python_version": _TEST_PYTHON_VERSION,
1820+
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
1821+
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
1822+
},
1823+
"class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1],
1824+
"agent_framework": _TEST_AGENT_FRAMEWORK,
1825+
},
1826+
}
1827+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
1828+
response=_genai_types.ReasoningEngine(
1829+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
1830+
spec=_TEST_AGENT_ENGINE_SPEC,
1831+
)
1832+
)
1833+
with mock.patch.object(
1834+
self.client.agent_engines._api_client, "request"
1835+
) as request_mock:
1836+
request_mock.return_value = genai_types.HttpResponse(body="")
1837+
self.client.agent_engines.create(
1838+
agent=self.test_agent,
1839+
config=_genai_types.AgentEngineConfig(
1840+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1841+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
1842+
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
1843+
staging_bucket=_TEST_STAGING_BUCKET,
1844+
agent_framework=_TEST_AGENT_FRAMEWORK,
1845+
),
1846+
)
1847+
mock_create_config.assert_called_with(
1848+
mode="create",
1849+
agent=self.test_agent,
1850+
staging_bucket=_TEST_STAGING_BUCKET,
1851+
requirements=_TEST_AGENT_ENGINE_REQUIREMENTS,
1852+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1853+
description=None,
1854+
gcs_dir_name=None,
1855+
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
1856+
env_vars=None,
1857+
service_account=None,
1858+
context_spec=None,
1859+
psc_interface_config=None,
1860+
min_instances=None,
1861+
max_instances=None,
1862+
resource_limits=None,
1863+
container_concurrency=None,
1864+
encryption_spec=None,
1865+
labels=None,
1866+
agent_server_mode=None,
1867+
class_methods=None,
1868+
source_packages=None,
1869+
entrypoint_module=None,
1870+
entrypoint_object=None,
1871+
requirements_file=None,
1872+
agent_framework=_TEST_AGENT_FRAMEWORK,
1873+
)
1874+
request_mock.assert_called_with(
1875+
"post",
1876+
"reasoningEngines",
1877+
{
1878+
"displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME,
1879+
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
1880+
"spec": {
1881+
"agent_framework": _TEST_AGENT_FRAMEWORK,
1882+
"class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1],
1883+
"package_spec": {
1884+
"pickle_object_gcs_uri": _TEST_AGENT_ENGINE_GCS_URI,
1885+
"python_version": _TEST_PYTHON_VERSION,
1886+
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
1887+
},
1888+
},
1889+
},
1890+
None,
1891+
)
1892+
18011893
@pytest.mark.usefixtures("caplog")
18021894
@mock.patch.object(_agent_engines_utils, "_prepare")
18031895
@mock.patch.object(_agent_engines_utils, "_await_operation")

vertexai/_genai/_agent_engines_utils.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@
128128
_BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES))
129129
_BLOB_FILENAME = "agent_engine.pkl"
130130
_DEFAULT_AGENT_FRAMEWORK = "custom"
131+
_SUPPORTED_AGENT_FRAMEWORKS = frozenset(
132+
[
133+
"google-adk",
134+
"langchain",
135+
"langgraph",
136+
"ag2",
137+
"llama-index",
138+
"custom",
139+
]
140+
)
131141
_DEFAULT_ASYNC_METHOD_NAME = "async_query"
132142
_DEFAULT_ASYNC_METHOD_RETURN_TYPE = "Coroutine[Any]"
133143
_DEFAULT_ASYNC_STREAM_METHOD_NAME = "async_stream_query"
@@ -705,13 +715,35 @@ def _generate_schema(
705715
return schema
706716

707717

708-
def _get_agent_framework(*, agent: _AgentEngineInterface) -> str:
709-
if (
710-
hasattr(agent, _AGENT_FRAMEWORK_ATTR)
711-
and getattr(agent, _AGENT_FRAMEWORK_ATTR) is not None
712-
and isinstance(getattr(agent, _AGENT_FRAMEWORK_ATTR), str)
713-
):
714-
return getattr(agent, _AGENT_FRAMEWORK_ATTR)
718+
def _get_agent_framework(*, agent_framework: str, agent: _AgentEngineInterface) -> str:
719+
"""Gets the agent framework to use.
720+
721+
It prioritizes the provided `agent_framework`. If not provided or not
722+
supported, it checks the `_AGENT_FRAMEWORK_ATTR` attribute on the agent.
723+
If neither is found, it defaults to "_DEFAULT_AGENT_FRAMEWORK".
724+
725+
Args:
726+
agent_framework (str):
727+
The agent framework provided by the user.
728+
agent (_AgentEngineInterface):
729+
The agent engine instance.
730+
731+
Returns:
732+
str: The name of the agent framework to use.
733+
"""
734+
if agent_framework is not None and agent_framework in _SUPPORTED_AGENT_FRAMEWORKS:
735+
logger.info(f"Using agent framework: {agent_framework}")
736+
return agent_framework
737+
if hasattr(agent, _AGENT_FRAMEWORK_ATTR):
738+
agent_framework_attr = getattr(agent, _AGENT_FRAMEWORK_ATTR)
739+
if (
740+
agent_framework_attr is not None
741+
and isinstance(agent_framework_attr, str)
742+
and agent_framework_attr in _SUPPORTED_AGENT_FRAMEWORKS
743+
):
744+
logger.info(f"Using agent framework: {agent_framework_attr}")
745+
return agent_framework_attr
746+
logger.info(f"Using default agent framework: {_DEFAULT_AGENT_FRAMEWORK}")
715747
return _DEFAULT_AGENT_FRAMEWORK
716748

717749

vertexai/_genai/agent_engines.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ def _CreateAgentEngineConfig_to_vertex(
9494
getv(from_object, ["requirements_file"]),
9595
)
9696

97+
if getv(from_object, ["agent_framework"]) is not None:
98+
setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"]))
99+
97100
return to_object
98101

99102

@@ -285,6 +288,9 @@ def _UpdateAgentEngineConfig_to_vertex(
285288
getv(from_object, ["requirements_file"]),
286289
)
287290

291+
if getv(from_object, ["agent_framework"]) is not None:
292+
setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"]))
293+
288294
if getv(from_object, ["update_mask"]) is not None:
289295
setv(
290296
parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"])
@@ -923,6 +929,7 @@ def create(
923929
entrypoint_module=config.entrypoint_module,
924930
entrypoint_object=config.entrypoint_object,
925931
requirements_file=config.requirements_file,
932+
agent_framework=config.agent_framework,
926933
)
927934
operation = self._create(config=api_config)
928935
# TODO: Use a more specific link.
@@ -986,6 +993,7 @@ def _create_config(
986993
entrypoint_module: Optional[str] = None,
987994
entrypoint_object: Optional[str] = None,
988995
requirements_file: Optional[str] = None,
996+
agent_framework: Optional[str] = None,
989997
) -> types.UpdateAgentEngineConfigDict:
990998
import sys
991999

@@ -1193,7 +1201,10 @@ def _create_config(
11931201
] = agent_server_mode
11941202

11951203
agent_engine_spec["agent_framework"] = (
1196-
_agent_engines_utils._get_agent_framework(agent=agent)
1204+
_agent_engines_utils._get_agent_framework(
1205+
agent_framework=agent_framework,
1206+
agent=agent,
1207+
)
11971208
)
11981209
update_masks.append("spec.agent_framework")
11991210
config["spec"] = agent_engine_spec
@@ -1421,6 +1432,7 @@ def update(
14211432
entrypoint_module=config.entrypoint_module,
14221433
entrypoint_object=config.entrypoint_object,
14231434
requirements_file=config.requirements_file,
1435+
agent_framework=config.agent_framework,
14241436
)
14251437
operation = self._update(name=name, config=api_config)
14261438
logger.info(

vertexai/_genai/types/common.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5654,6 +5654,17 @@ class CreateAgentEngineConfig(_common.BaseModel):
56545654
the source package.
56555655
""",
56565656
)
5657+
agent_framework: Optional[str] = Field(
5658+
default=None,
5659+
description="""The agent framework to be used for the Agent Engine.
5660+
The OSS agent framework used to develop the agent.
5661+
Currently supported values: "google-adk", "langchain", "langgraph",
5662+
"ag2", "llama-index", "custom".
5663+
If not specified:
5664+
- If `agent` is specified, the agent framework will be auto-detected.
5665+
- If `source_packages` is specified, the agent framework will be
5666+
default to "custom".""",
5667+
)
56575668

56585669

56595670
class CreateAgentEngineConfigDict(TypedDict, total=False):
@@ -5752,6 +5763,16 @@ class CreateAgentEngineConfigDict(TypedDict, total=False):
57525763
the source package.
57535764
"""
57545765

5766+
agent_framework: Optional[str]
5767+
"""The agent framework to be used for the Agent Engine.
5768+
The OSS agent framework used to develop the agent.
5769+
Currently supported values: "google-adk", "langchain", "langgraph",
5770+
"ag2", "llama-index", "custom".
5771+
If not specified:
5772+
- If `agent` is specified, the agent framework will be auto-detected.
5773+
- If `source_packages` is specified, the agent framework will be
5774+
default to "custom"."""
5775+
57555776

57565777
CreateAgentEngineConfigOrDict = Union[
57575778
CreateAgentEngineConfig, CreateAgentEngineConfigDict
@@ -6355,6 +6376,17 @@ class UpdateAgentEngineConfig(_common.BaseModel):
63556376
the source package.
63566377
""",
63576378
)
6379+
agent_framework: Optional[str] = Field(
6380+
default=None,
6381+
description="""The agent framework to be used for the Agent Engine.
6382+
The OSS agent framework used to develop the agent.
6383+
Currently supported values: "google-adk", "langchain", "langgraph",
6384+
"ag2", "llama-index", "custom".
6385+
If not specified:
6386+
- If `agent` is specified, the agent framework will be auto-detected.
6387+
- If `source_packages` is specified, the agent framework will be
6388+
default to "custom".""",
6389+
)
63586390
update_mask: Optional[str] = Field(
63596391
default=None,
63606392
description="""The update mask to apply. For the `FieldMask` definition, see
@@ -6458,6 +6490,16 @@ class UpdateAgentEngineConfigDict(TypedDict, total=False):
64586490
the source package.
64596491
"""
64606492

6493+
agent_framework: Optional[str]
6494+
"""The agent framework to be used for the Agent Engine.
6495+
The OSS agent framework used to develop the agent.
6496+
Currently supported values: "google-adk", "langchain", "langgraph",
6497+
"ag2", "llama-index", "custom".
6498+
If not specified:
6499+
- If `agent` is specified, the agent framework will be auto-detected.
6500+
- If `source_packages` is specified, the agent framework will be
6501+
default to "custom"."""
6502+
64616503
update_mask: Optional[str]
64626504
"""The update mask to apply. For the `FieldMask` definition, see
64636505
https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask."""
@@ -13212,6 +13254,17 @@ class AgentEngineConfig(_common.BaseModel):
1321213254
the source package.
1321313255
""",
1321413256
)
13257+
agent_framework: Optional[str] = Field(
13258+
default=None,
13259+
description="""The agent framework to be used for the Agent Engine.
13260+
The OSS agent framework used to develop the agent.
13261+
Currently supported values: "google-adk", "langchain", "langgraph",
13262+
"ag2", "llama-index", "custom".
13263+
If not specified:
13264+
- If `agent` is specified, the agent framework will be auto-detected.
13265+
- If `source_packages` is specified, the agent framework will be
13266+
default to "custom".""",
13267+
)
1321513268

1321613269

1321713270
class AgentEngineConfigDict(TypedDict, total=False):
@@ -13339,6 +13392,16 @@ class AgentEngineConfigDict(TypedDict, total=False):
1333913392
the source package.
1334013393
"""
1334113394

13395+
agent_framework: Optional[str]
13396+
"""The agent framework to be used for the Agent Engine.
13397+
The OSS agent framework used to develop the agent.
13398+
Currently supported values: "google-adk", "langchain", "langgraph",
13399+
"ag2", "llama-index", "custom".
13400+
If not specified:
13401+
- If `agent` is specified, the agent framework will be auto-detected.
13402+
- If `source_packages` is specified, the agent framework will be
13403+
default to "custom"."""
13404+
1334213405

1334313406
AgentEngineConfigOrDict = Union[AgentEngineConfig, AgentEngineConfigDict]
1334413407

0 commit comments

Comments
 (0)