Skip to content

Commit 183303e

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Using reflection for LlmRegistry and ComponentRegistry
PiperOrigin-RevId: 855224096
1 parent c6c9557 commit 183303e

File tree

3 files changed

+109
-69
lines changed

3 files changed

+109
-69
lines changed

core/src/main/java/com/google/adk/models/LlmRegistry.java

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ public interface LlmFactory {
3434
/** Map of model name patterns regex to factories. */
3535
private static final Map<String, LlmFactory> llmFactories = new ConcurrentHashMap<>();
3636

37-
/** Registers default LLM factories, e.g. for Gemini models. */
37+
/* Registers default LLM factories, e.g. for Gemini models. */
3838
static {
39-
registerLlm("gemini-.*", modelName -> Gemini.builder().modelName(modelName).build());
40-
registerLlm("apigee/.*", modelName -> ApigeeLlm.builder().modelName(modelName).build());
39+
registerViaReflection("com.google.adk.models.Gemini", "gemini-.*");
40+
registerViaReflection("com.google.adk.models.ApigeeLlm", "apigee/.*");
4141
}
4242

4343
/**
@@ -78,6 +78,31 @@ private static BaseLlm createLlm(String modelName) {
7878
throw new IllegalArgumentException("Unsupported model: " + modelName);
7979
}
8080

81+
/**
82+
* Registers an LLM factory via reflection, if the class is available.
83+
*
84+
* @param className The fully qualified class name of the LLM.
85+
* @param pattern The regex pattern for matching model names.
86+
*/
87+
private static void registerViaReflection(String className, String pattern) {
88+
try {
89+
Class<?> llmClass = Class.forName(className);
90+
LlmFactory factory =
91+
modelName -> {
92+
try {
93+
Object builder = llmClass.getMethod("builder").invoke(null);
94+
builder.getClass().getMethod("modelName", String.class).invoke(builder, modelName);
95+
return (BaseLlm) builder.getClass().getMethod("build").invoke(builder);
96+
} catch (ReflectiveOperationException e) {
97+
throw new IllegalArgumentException("Failed to create instance of " + className, e);
98+
}
99+
};
100+
registerLlm(pattern, factory);
101+
} catch (ClassNotFoundException e) {
102+
// ignore - LLM not available.
103+
}
104+
}
105+
81106
/**
82107
* Registers an LLM factory for testing purposes. Clears cached instances matching the given
83108
* pattern to ensure test isolation.

core/src/main/java/com/google/adk/utils/ComponentRegistry.java

Lines changed: 80 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,13 @@
2222

2323
import com.google.adk.agents.BaseAgent;
2424
import com.google.adk.agents.Callbacks;
25-
import com.google.adk.agents.LlmAgent;
26-
import com.google.adk.agents.LoopAgent;
27-
import com.google.adk.agents.ParallelAgent;
28-
import com.google.adk.agents.SequentialAgent;
29-
import com.google.adk.tools.AgentTool;
3025
import com.google.adk.tools.BaseTool;
3126
import com.google.adk.tools.BaseToolset;
32-
import com.google.adk.tools.ExampleTool;
33-
import com.google.adk.tools.ExitLoopTool;
34-
import com.google.adk.tools.GoogleMapsTool;
35-
import com.google.adk.tools.GoogleSearchTool;
36-
import com.google.adk.tools.LoadArtifactsTool;
37-
import com.google.adk.tools.LongRunningFunctionTool;
38-
import com.google.adk.tools.UrlContextTool;
39-
import com.google.adk.tools.mcp.McpToolset;
27+
import com.google.common.collect.ImmutableMap;
4028
import java.util.Map;
4129
import java.util.Optional;
4230
import java.util.Set;
4331
import java.util.concurrent.ConcurrentHashMap;
44-
import javax.annotation.Nonnull;
4532
import org.slf4j.Logger;
4633
import org.slf4j.LoggerFactory;
4734

@@ -93,6 +80,8 @@
9380
public class ComponentRegistry {
9481

9582
private static final Logger logger = LoggerFactory.getLogger(ComponentRegistry.class);
83+
private static volatile ImmutableMap<String, Object> DEFAULT_REGISTRY;
84+
9685
private static volatile ComponentRegistry instance = new ComponentRegistry();
9786

9887
private final Map<String, Object> registry = new ConcurrentHashMap<>();
@@ -103,55 +92,80 @@ protected ComponentRegistry() {
10392

10493
/** Initializes the registry with base pre-wired ADK instances. */
10594
private void initializePreWiredEntries() {
106-
registerAdkAgentClass(LlmAgent.class);
107-
registerAdkAgentClass(LoopAgent.class);
108-
registerAdkAgentClass(ParallelAgent.class);
109-
registerAdkAgentClass(SequentialAgent.class);
110-
111-
registerAdkToolInstance("google_search", GoogleSearchTool.INSTANCE);
112-
registerAdkToolInstance("load_artifacts", LoadArtifactsTool.INSTANCE);
113-
registerAdkToolInstance("exit_loop", ExitLoopTool.INSTANCE);
114-
registerAdkToolInstance("url_context", UrlContextTool.INSTANCE);
115-
registerAdkToolInstance("google_maps_grounding", GoogleMapsTool.INSTANCE);
116-
117-
registerAdkToolClass(AgentTool.class);
118-
registerAdkToolClass(LongRunningFunctionTool.class);
119-
registerAdkToolClass(ExampleTool.class);
120-
121-
registerAdkToolsetClass(McpToolset.class);
122-
// TODO: add all python tools that also exist in Java.
123-
95+
if (DEFAULT_REGISTRY == null) {
96+
synchronized (ComponentRegistry.class) {
97+
if (DEFAULT_REGISTRY == null) {
98+
registerAdkClassByName("com.google.adk.agents.LlmAgent");
99+
registerAdkClassByName("com.google.adk.agents.LoopAgent");
100+
registerAdkClassByName("com.google.adk.agents.ParallelAgent");
101+
registerAdkClassByName("com.google.adk.agents.SequentialAgent");
102+
103+
registerAdkToolInstance("google_search", "com.google.adk.tools.GoogleSearchTool");
104+
registerAdkToolInstance("load_artifacts", "com.google.adk.tools.LoadArtifactsTool");
105+
registerAdkToolInstance("exit_loop", "com.google.adk.tools.ExitLoopTool");
106+
registerAdkToolInstance("url_context", "com.google.adk.tools.UrlContextTool");
107+
registerAdkToolInstance("google_maps_grounding", "com.google.adk.tools.GoogleMapsTool");
108+
109+
registerAdkClassByName("com.google.adk.tools.AgentTool");
110+
registerAdkClassByName("com.google.adk.tools.LongRunningFunctionTool");
111+
registerAdkClassByName("com.google.adk.tools.ExampleTool");
112+
113+
registerAdkClassByName("com.google.adk.tools.mcp.McpToolset");
114+
// TODO: add all python tools that also exist in Java.
115+
116+
DEFAULT_REGISTRY = ImmutableMap.copyOf(registry);
117+
return;
118+
}
119+
}
120+
}
121+
registry.putAll(DEFAULT_REGISTRY);
124122
logger.debug("Initialized base pre-wired entries in ComponentRegistry");
125123
}
126124

127-
private void registerAdkAgentClass(Class<? extends BaseAgent> agentClass) {
128-
registry.put(agentClass.getName(), agentClass);
129-
// For python compatibility, also register the name used in ADK Python.
130-
registry.put("google.adk.agents." + agentClass.getSimpleName(), agentClass);
131-
}
125+
private void registerAdkClassByName(String className) {
126+
try {
127+
Class<?> clazz = Thread.currentThread().getContextClassLoader().loadClass(className);
128+
String standardPrefix;
129+
if (BaseAgent.class.isAssignableFrom(clazz)) {
130+
standardPrefix = "google.adk.agents.";
131+
} else if (BaseTool.class.isAssignableFrom(clazz)) {
132+
standardPrefix = "google.adk.tools.";
133+
} else if (BaseToolset.class.isAssignableFrom(clazz)) {
134+
standardPrefix = "google.adk.tools.";
135+
} else {
136+
throw new IllegalArgumentException(
137+
"Cannot determine standardPrefix for type " + clazz.getName());
138+
}
132139

133-
private void registerAdkToolInstance(String name, @Nonnull Object toolInstance) {
134-
registry.put(name, toolInstance);
135-
// For python compatibility, also register the name used in ADK Python.
136-
registry.put("google.adk.tools." + name, toolInstance);
137-
}
140+
registry.put(clazz.getName(), clazz);
141+
// For python compatibility, also register the name used in ADK Python.
142+
registry.put(standardPrefix + clazz.getSimpleName(), clazz);
138143

139-
private void registerAdkToolClass(@Nonnull Class<?> toolClass) {
140-
registry.put(toolClass.getName(), toolClass);
141-
// For python compatibility, also register the name used in ADK Python.
142-
registry.put("google.adk.tools." + toolClass.getSimpleName(), toolClass);
144+
if (BaseToolset.class.isAssignableFrom(clazz)) {
145+
registry.put(clazz.getSimpleName(), clazz);
146+
if (clazz.getSimpleName().equals("McpToolset")) {
147+
registry.put("mcp.McpToolset", clazz);
148+
}
149+
}
150+
} catch (Exception e) {
151+
logger.info(
152+
"{} not found, skipping registration: {}",
153+
className.substring(className.lastIndexOf('.') + 1),
154+
e.getMessage());
155+
}
143156
}
144157

145-
private void registerAdkToolsetClass(@Nonnull Class<? extends BaseToolset> toolsetClass) {
146-
registry.put(toolsetClass.getName(), toolsetClass);
147-
// For python compatibility, also register the name used in ADK Python.
148-
registry.put("google.adk.tools." + toolsetClass.getSimpleName(), toolsetClass);
149-
// Also register by simple class name
150-
registry.put(toolsetClass.getSimpleName(), toolsetClass);
151-
// Special support for toolsets with various naming conventions
152-
String simpleName = toolsetClass.getSimpleName();
153-
if (simpleName.equals("McpToolset")) {
154-
registry.put("mcp.McpToolset", toolsetClass);
158+
private void registerAdkToolInstance(String name, String toolClassName) {
159+
try {
160+
Object toolInstance = Class.forName(toolClassName).getField("INSTANCE").get(null);
161+
registry.put(name, toolInstance);
162+
// For python compatibility, also register the name used in ADK Python.
163+
registry.put("google.adk.tools." + name, toolInstance);
164+
} catch (Exception e) {
165+
logger.info(
166+
"{} not found, skipping registration: {}",
167+
toolClassName.substring(toolClassName.lastIndexOf('.') + 1),
168+
e.getMessage());
155169
}
156170
}
157171

@@ -281,30 +295,31 @@ public static Optional<BaseAgent> resolveAgentInstance(String name) {
281295
*/
282296
public static Class<? extends BaseAgent> resolveAgentClass(String agentClassName) {
283297
// If no agent_class is specified, it will default to LlmAgent.
284-
if (isNullOrEmpty(agentClassName)) {
285-
return LlmAgent.class;
286-
}
298+
final String effectiveAgentClassName =
299+
isNullOrEmpty(agentClassName) ? "com.google.adk.agents.LlmAgent" : agentClassName;
287300

288301
Optional<Class<? extends BaseAgent>> agentClass;
289302

290-
if (agentClassName.contains(".")) {
303+
if (effectiveAgentClassName.contains(".")) {
291304
// If agentClassName contains '.', use it directly
292-
agentClass = getType(agentClassName, BaseAgent.class);
305+
agentClass = getType(effectiveAgentClassName, BaseAgent.class);
293306
} else {
294307
// First try the simple name
295308
agentClass =
296-
getType(agentClassName, BaseAgent.class)
309+
getType(effectiveAgentClassName, BaseAgent.class)
297310
// If not found, try with com.google.adk.agents prefix
298-
.or(() -> getType("com.google.adk.agents." + agentClassName, BaseAgent.class))
311+
.or(
312+
() ->
313+
getType("com.google.adk.agents." + effectiveAgentClassName, BaseAgent.class))
299314
// For Python compatibility, also try with google.adk.agents prefix
300-
.or(() -> getType("google.adk.agents." + agentClassName, BaseAgent.class));
315+
.or(() -> getType("google.adk.agents." + effectiveAgentClassName, BaseAgent.class));
301316
}
302317

303318
return agentClass.orElseThrow(
304319
() ->
305320
new IllegalArgumentException(
306321
"agentClass '"
307-
+ agentClassName
322+
+ effectiveAgentClassName
308323
+ "' is not in registry or not a subclass of BaseAgent."));
309324
}
310325

core/src/test/java/com/google/adk/utils/ComponentRegistryTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ public void testResolveToolClass_withFullyQualifiedName() {
298298

299299
@Test
300300
public void testMcpToolsetRegistration() {
301-
ComponentRegistry registry = ComponentRegistry.getInstance();
301+
ComponentRegistry registry = new ComponentRegistry();
302302

303303
// Verify direct registry storage (tests lines 134, 136, 138, 142 in ComponentRegistry.java)
304304
Optional<Object> directFullName = registry.get("com.google.adk.tools.mcp.McpToolset");

0 commit comments

Comments
 (0)