Skip to content

Commit 8c107d2

Browse files
Jacksunweicopybara-github
authored andcommitted
feat(config): Adds resolveAgentClass, resolveToolInstance and resolveToolClass to ComponentRegistry for resolving the 3 type of components
Uses `resolveAgentClass` in ConfigAgentUtils and `resolveToolInstance` in LlmAgent. PiperOrigin-RevId: 795272083
1 parent e5b1fb3 commit 8c107d2

File tree

6 files changed

+203
-150
lines changed

6 files changed

+203
-150
lines changed

core/src/main/java/com/google/adk/agents/ConfigAgentUtils.java

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616

1717
package com.google.adk.agents;
1818

19-
import static com.google.common.base.Strings.isNullOrEmpty;
20-
2119
import com.fasterxml.jackson.databind.DeserializationFeature;
2220
import com.fasterxml.jackson.databind.ObjectMapper;
2321
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
22+
import com.google.adk.utils.ComponentRegistry;
2423
import java.io.File;
2524
import java.io.FileInputStream;
2625
import java.io.IOException;
@@ -59,7 +58,8 @@ public static BaseAgent fromConfig(String configPath) throws ConfigurationExcept
5958
try {
6059
// Load the base config to determine the agent class
6160
BaseAgentConfig baseConfig = loadConfigAsType(absolutePath, BaseAgentConfig.class);
62-
Class<? extends BaseAgent> agentClass = resolveAgentClass(baseConfig.agentClass());
61+
Class<? extends BaseAgent> agentClass =
62+
ComponentRegistry.resolveAgentClass(baseConfig.agentClass());
6363

6464
// Load the config file with the specific config class
6565
Class<? extends BaseAgentConfig> configClass = getConfigClassForAgent(agentClass);
@@ -97,32 +97,6 @@ private static <T extends BaseAgentConfig> T loadConfigAsType(
9797
}
9898
}
9999

100-
/**
101-
* Resolves the agent class based on the agent class name from the configuration.
102-
*
103-
* @param agentClassName the name of the agent class from the config
104-
* @return the corresponding agent class
105-
* @throws ConfigurationException if the agent class is not supported
106-
*/
107-
private static Class<? extends BaseAgent> resolveAgentClass(String agentClassName)
108-
throws ConfigurationException {
109-
// If no agent_class is specified in the yaml file, it will default to LlmAgent.
110-
if (isNullOrEmpty(agentClassName) || agentClassName.equals("LlmAgent")) {
111-
return LlmAgent.class;
112-
}
113-
114-
// TODO: Support more agent classes
115-
// Example for future extensions:
116-
// if (agentClassName.equals("CustomAgent")) {
117-
// return CustomAgent.class;
118-
// }
119-
120-
throw new ConfigurationException(
121-
"agentClass '"
122-
+ agentClassName
123-
+ "' is not supported. It must be a subclass of BaseAgent.");
124-
}
125-
126100
/**
127101
* Maps agent classes to their corresponding config classes.
128102
*

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 9 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.agents;
1818

19+
import static com.google.common.base.Strings.isNullOrEmpty;
1920
import static com.google.common.base.Strings.nullToEmpty;
2021
import static java.util.stream.Collectors.joining;
2122

@@ -53,7 +54,7 @@
5354
import com.google.adk.tools.BaseTool;
5455
import com.google.adk.tools.BaseTool.ToolConfig;
5556
import com.google.adk.tools.BaseToolset;
56-
import com.google.common.base.CaseFormat;
57+
import com.google.adk.utils.ComponentRegistry;
5758
import com.google.common.base.Preconditions;
5859
import com.google.common.collect.ImmutableList;
5960
import com.google.errorprone.annotations.CanIgnoreReturnValue;
@@ -63,7 +64,6 @@
6364
import io.reactivex.rxjava3.core.Flowable;
6465
import io.reactivex.rxjava3.core.Maybe;
6566
import io.reactivex.rxjava3.core.Single;
66-
import java.lang.reflect.Constructor;
6767
import java.util.ArrayList;
6868
import java.util.List;
6969
import java.util.Map;
@@ -950,23 +950,21 @@ private static ImmutableList<BaseTool> resolveTools(
950950

951951
for (ToolConfig toolConfig : toolConfigs) {
952952
try {
953-
String toolName = toolConfig.name();
954-
if (toolName == null || toolName.trim().isEmpty()) {
953+
if (isNullOrEmpty(toolConfig.name())) {
955954
throw new ConfigurationException("Tool name cannot be empty");
956955
}
957956

958-
toolName = toolName.trim();
959-
BaseTool tool;
960-
961-
if (!toolName.contains(".")) {
962-
tool = resolveBuiltInTool(toolName, toolConfig);
957+
String toolName = toolConfig.name().trim();
958+
Optional<BaseTool> toolOpt = ComponentRegistry.resolveToolInstance(toolName);
959+
if (toolOpt.isPresent()) {
960+
resolvedTools.add(toolOpt.get());
963961
} else {
964962
// TODO: Support user-defined tools
963+
// TODO: Support using tool class via ComponentRegistry.resolveToolClass
965964
logger.debug("configAbsPath is: {}", configAbsPath);
966-
throw new ConfigurationException("User-defined tools are not yet supported");
965+
throw new ConfigurationException("Tool not found: " + toolName);
967966
}
968967

969-
resolvedTools.add(tool);
970968
logger.debug("Successfully resolved tool: {}", toolConfig.name());
971969
} catch (Exception e) {
972970
String errorMsg = "Failed to resolve tool: " + toolConfig.name();
@@ -977,77 +975,4 @@ private static ImmutableList<BaseTool> resolveTools(
977975

978976
return ImmutableList.copyOf(resolvedTools);
979977
}
980-
981-
private static BaseTool resolveBuiltInTool(String toolName, ToolConfig toolConfig)
982-
throws ConfigurationException {
983-
try {
984-
logger.debug("Resolving built-in tool: {}", toolName);
985-
// TODO: Handle built-in tool name end with Tool while config yaml file does not.
986-
// e.g.google_search in config yaml file and GoogleSearchTool in tool class name.
987-
String pascalCaseToolName = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, toolName);
988-
String className = "com.google.adk.tools." + pascalCaseToolName;
989-
Class<?> toolClass;
990-
try {
991-
toolClass = Class.forName(className);
992-
logger.debug("Successfully loaded tool class: {}", className);
993-
} catch (ClassNotFoundException e) {
994-
String fallbackClassName = "com.google.adk.tools." + toolName;
995-
try {
996-
toolClass = Class.forName(fallbackClassName);
997-
} catch (ClassNotFoundException e2) {
998-
throw new ConfigurationException(
999-
"Built-in tool not found: "
1000-
+ toolName
1001-
+ ". Expected class: "
1002-
+ className
1003-
+ " or "
1004-
+ fallbackClassName,
1005-
e2);
1006-
}
1007-
}
1008-
1009-
if (!BaseTool.class.isAssignableFrom(toolClass)) {
1010-
throw new ConfigurationException(
1011-
"Built-in tool class " + toolClass.getName() + " does not extend BaseTool");
1012-
}
1013-
1014-
@SuppressWarnings("unchecked")
1015-
Class<? extends BaseTool> baseToolClass = (Class<? extends BaseTool>) toolClass;
1016-
1017-
BaseTool tool = createToolInstance(baseToolClass, toolConfig);
1018-
logger.info(
1019-
"Successfully created built-in tool: {} (class: {})", toolName, toolClass.getName());
1020-
1021-
return tool;
1022-
1023-
} catch (Exception e) {
1024-
logger.error("Failed to create built-in tool: {}", toolName, e);
1025-
throw new ConfigurationException("Failed to create built-in tool: " + toolName, e);
1026-
}
1027-
}
1028-
1029-
private static BaseTool createToolInstance(
1030-
Class<? extends BaseTool> toolClass, ToolConfig toolConfig)
1031-
throws ConfigAgentUtils.ConfigurationException {
1032-
1033-
try {
1034-
// TODO:implement constructor with ToolArgsConfig
1035-
logger.debug("ToolConfig is: {}", toolConfig);
1036-
1037-
// Try default constructor
1038-
try {
1039-
Constructor<? extends BaseTool> constructor = toolClass.getConstructor();
1040-
return constructor.newInstance();
1041-
} catch (NoSuchMethodException e) {
1042-
// Continue
1043-
}
1044-
1045-
throw new ConfigAgentUtils.ConfigurationException(
1046-
"No suitable constructor found for tool class: " + toolClass.getName());
1047-
1048-
} catch (Exception e) {
1049-
throw new ConfigAgentUtils.ConfigurationException(
1050-
"Failed to instantiate tool class: " + toolClass.getName(), e);
1051-
}
1052-
}
1053978
}

core/src/main/java/com/google/adk/utils/ComponentRegistry.java

Lines changed: 145 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,21 @@
1616

1717
package com.google.adk.utils;
1818

19-
import com.google.adk.tools.BuiltInCodeExecutionTool;
19+
import static com.google.common.base.Strings.isNullOrEmpty;
20+
21+
import com.google.adk.agents.BaseAgent;
22+
import com.google.adk.agents.LlmAgent;
23+
import com.google.adk.agents.LoopAgent;
24+
import com.google.adk.agents.ParallelAgent;
25+
import com.google.adk.agents.SequentialAgent;
26+
import com.google.adk.tools.AgentTool;
27+
import com.google.adk.tools.BaseTool;
2028
import com.google.adk.tools.GoogleSearchTool;
29+
import com.google.adk.tools.LoadArtifactsTool;
2130
import java.util.Map;
2231
import java.util.Optional;
2332
import java.util.concurrent.ConcurrentHashMap;
33+
import javax.annotation.Nonnull;
2434
import org.slf4j.Logger;
2535
import org.slf4j.LoggerFactory;
2636

@@ -74,18 +84,44 @@ public class ComponentRegistry {
7484

7585
private final Map<String, Object> registry = new ConcurrentHashMap<>();
7686

77-
public ComponentRegistry() {
87+
protected ComponentRegistry() {
7888
initializePreWiredEntries();
7989
}
8090

8191
/** Initializes the registry with base pre-wired ADK instances. */
8292
private void initializePreWiredEntries() {
83-
registry.put("google_search", new GoogleSearchTool());
84-
registry.put("code_execution", new BuiltInCodeExecutionTool());
93+
registerAdkAgentClass(LlmAgent.class);
94+
registerAdkAgentClass(LoopAgent.class);
95+
registerAdkAgentClass(ParallelAgent.class);
96+
registerAdkAgentClass(SequentialAgent.class);
97+
98+
registerAdkToolInstance("google_search", new GoogleSearchTool());
99+
registerAdkToolInstance("load_artifacts", new LoadArtifactsTool());
100+
101+
registerAdkToolClass(AgentTool.class);
102+
// TODO: add all python tools that also exist in Java.
85103

86104
logger.debug("Initialized base pre-wired entries in ComponentRegistry");
87105
}
88106

107+
private void registerAdkAgentClass(Class<? extends BaseAgent> agentClass) {
108+
registry.put(agentClass.getName(), agentClass);
109+
// For python compatibility, also register the name used in ADK Python.
110+
registry.put("google.adk.agents." + agentClass.getSimpleName(), agentClass);
111+
}
112+
113+
private void registerAdkToolInstance(String name, @Nonnull Object toolInstance) {
114+
registry.put(name, toolInstance);
115+
// For python compatibility, also register the name used in ADK Python.
116+
registry.put("google.adk.tools." + name, toolInstance);
117+
}
118+
119+
private void registerAdkToolClass(@Nonnull Class<?> toolClass) {
120+
registry.put(toolClass.getName(), toolClass);
121+
// For python compatibility, also register the name used in ADK Python.
122+
registry.put("google.adk.tools." + toolClass.getSimpleName(), toolClass);
123+
}
124+
89125
/**
90126
* Registers an object with the given name. This can override pre-wired entries.
91127
*
@@ -182,4 +218,109 @@ public static synchronized void setInstance(ComponentRegistry newInstance) {
182218
instance = newInstance;
183219
logger.info("ComponentRegistry singleton instance updated");
184220
}
221+
222+
/**
223+
* Resolves the agent class based on the agent class name from the configuration.
224+
*
225+
* @param agentClassName the name of the agent class from the config
226+
* @return the corresponding agent class
227+
* @throws IllegalArgumentException if the agent class is not supported
228+
*/
229+
@SuppressWarnings({"unchecked", "rawtypes"}) // For type casting.
230+
public static Class<? extends BaseAgent> resolveAgentClass(String agentClassName) {
231+
// If no agent_class is specified, it will default to LlmAgent.
232+
if (isNullOrEmpty(agentClassName)) {
233+
return LlmAgent.class;
234+
}
235+
236+
ComponentRegistry registry = getInstance();
237+
238+
if (agentClassName.contains(".")) {
239+
// If agentClassName contains '.', use it directly
240+
Optional<Class> agentClass = registry.get(agentClassName, Class.class);
241+
if (agentClass.isPresent() && BaseAgent.class.isAssignableFrom(agentClass.get())) {
242+
return (Class<? extends BaseAgent>) agentClass.get();
243+
}
244+
} else {
245+
// First try the simple name
246+
Optional<Class> agentClass = registry.get(agentClassName, Class.class);
247+
if (agentClass.isPresent() && BaseAgent.class.isAssignableFrom(agentClass.get())) {
248+
return (Class<? extends BaseAgent>) agentClass.get();
249+
}
250+
251+
// If not found, try with com.google.adk.agents prefix
252+
agentClass = registry.get("com.google.adk.agents." + agentClassName, Class.class);
253+
if (agentClass.isPresent() && BaseAgent.class.isAssignableFrom(agentClass.get())) {
254+
return (Class<? extends BaseAgent>) agentClass.get();
255+
}
256+
}
257+
258+
throw new IllegalArgumentException(
259+
"agentClass '" + agentClassName + "' is not in registry or not a subclass of BaseAgent.");
260+
}
261+
262+
/**
263+
* Resolves the tool instance based on the tool name from the configuration.
264+
*
265+
* @param name the name of the tool from the config
266+
* @return an Optional containing the tool instance if found, empty otherwise
267+
*/
268+
public static Optional<BaseTool> resolveToolInstance(String name) {
269+
if (isNullOrEmpty(name)) {
270+
return Optional.empty();
271+
}
272+
273+
ComponentRegistry registry = getInstance();
274+
275+
if (name.contains(".")) {
276+
// If name contains '.', use it directly
277+
return registry.get(name, BaseTool.class);
278+
} else {
279+
// First try the simple name
280+
Optional<BaseTool> toolInstance = registry.get(name, BaseTool.class);
281+
if (toolInstance.isPresent()) {
282+
return toolInstance;
283+
}
284+
285+
// If not found, try with google.adk.tools prefix
286+
return registry.get("google.adk.tools." + name, BaseTool.class);
287+
}
288+
}
289+
290+
/**
291+
* Resolves the tool class based on the tool class name from the configuration.
292+
*
293+
* @param toolClassName the name of the tool class from the config
294+
* @return an Optional containing the tool class if found, empty otherwise
295+
*/
296+
@SuppressWarnings({"unchecked", "rawtypes"}) // For type casting.
297+
public static Optional<Class<? extends BaseTool>> resolveToolClass(String toolClassName) {
298+
if (isNullOrEmpty(toolClassName)) {
299+
return Optional.empty();
300+
}
301+
302+
ComponentRegistry registry = getInstance();
303+
304+
if (toolClassName.contains(".")) {
305+
// If toolClassName contains '.', use it directly
306+
Optional<Class> toolClass = registry.get(toolClassName, Class.class);
307+
if (toolClass.isPresent() && BaseTool.class.isAssignableFrom(toolClass.get())) {
308+
return Optional.of((Class<? extends BaseTool>) toolClass.get());
309+
}
310+
} else {
311+
// First try the simple name
312+
Optional<Class> toolClass = registry.get(toolClassName, Class.class);
313+
if (toolClass.isPresent() && BaseTool.class.isAssignableFrom(toolClass.get())) {
314+
return Optional.of((Class<? extends BaseTool>) toolClass.get());
315+
}
316+
317+
// If not found, try with google.adk.tools prefix
318+
toolClass = registry.get("google.adk.tools." + toolClassName, Class.class);
319+
if (toolClass.isPresent() && BaseTool.class.isAssignableFrom(toolClass.get())) {
320+
return Optional.of((Class<? extends BaseTool>) toolClass.get());
321+
}
322+
}
323+
324+
return Optional.empty();
325+
}
185326
}

0 commit comments

Comments
 (0)