Skip to content

Commit c2f2a5a

Browse files
Doris26copybara-github
authored andcommitted
feat: Update model resolution logic for LLM agents
This change refactors the model resolution logic within the LlmAgent to utilize the LlmRegistry for looking up model instances by name. It also introduces a canonicalModel() method to ensure a BaseLlm instance is consistently available, simplifying model access in flows. PiperOrigin-RevId: 791915369
1 parent 9723f8a commit c2f2a5a

File tree

4 files changed

+93
-44
lines changed

4 files changed

+93
-44
lines changed

core/src/main/java/com/google/adk/agents/BaseAgentConfig.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626
public class BaseAgentConfig {
2727
private String name;
2828
private String description = "";
29-
// TODO: Add agentClassType enum to the config and handle different values from user
30-
// input.e.g.LLM_AGENT, LlmAgent
31-
private String agentClass = null;
29+
private String agentClass;
3230

3331
@JsonProperty(value = "name", required = true)
3432
public String name() {

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package com.google.adk.agents;
1818

19-
import static com.google.common.base.Strings.isNullOrEmpty;
2019
import static com.google.common.base.Strings.nullToEmpty;
2120
import static java.util.stream.Collectors.joining;
2221

@@ -47,7 +46,7 @@
4746
import com.google.adk.flows.llmflows.BaseLlmFlow;
4847
import com.google.adk.flows.llmflows.SingleFlow;
4948
import com.google.adk.models.BaseLlm;
50-
import com.google.adk.models.Gemini;
49+
import com.google.adk.models.LlmRegistry;
5150
import com.google.adk.models.Model;
5251
import com.google.adk.tools.BaseTool;
5352
import com.google.adk.tools.BaseToolset;
@@ -835,8 +834,17 @@ public Model resolvedModel() {
835834
*/
836835
private Model resolveModelInternal() {
837836
if (this.model.isPresent()) {
838-
if (this.model().isPresent()) {
839-
return this.model.get();
837+
Model currentModel = this.model.get();
838+
839+
if (currentModel.model().isPresent()) {
840+
return currentModel;
841+
}
842+
843+
if (currentModel.modelName().isPresent()) {
844+
String modelName = currentModel.modelName().get();
845+
BaseLlm resolvedLlm = LlmRegistry.getLlm(modelName);
846+
847+
return Model.builder().modelName(modelName).model(resolvedLlm).build();
840848
}
841849
}
842850
BaseAgent current = this.parentAgent();
@@ -879,44 +887,8 @@ public static LlmAgent fromConfig(LlmAgentConfig config, String configAbsPath)
879887
.description(nullToEmpty(config.description()))
880888
.instruction(config.instruction());
881889

882-
// Set optional model configuration
883890
if (config.model() != null && !config.model().trim().isEmpty()) {
884-
logger.info("Configuring model: {}", config.model());
885-
886-
// TODO: resolve model name
887-
if (config.model().startsWith("gemini")) {
888-
try {
889-
// Check for API key in system properties (for testing) or environment variables
890-
String apiKey = System.getProperty("GOOGLE_API_KEY");
891-
if (isNullOrEmpty(apiKey)) {
892-
apiKey = System.getProperty("GEMINI_API_KEY");
893-
}
894-
if (isNullOrEmpty(apiKey)) {
895-
apiKey = System.getenv("GOOGLE_API_KEY");
896-
}
897-
if (isNullOrEmpty(apiKey)) {
898-
apiKey = System.getenv("GEMINI_API_KEY");
899-
}
900-
901-
Gemini.Builder geminiBuilder = Gemini.builder().modelName(config.model());
902-
if (apiKey != null && !apiKey.isEmpty()) {
903-
geminiBuilder.apiKey(apiKey);
904-
}
905-
906-
BaseLlm model = geminiBuilder.build();
907-
builder.model(model);
908-
logger.debug("Successfully configured Gemini model: {}", config.model());
909-
} catch (RuntimeException e) {
910-
logger.warn(
911-
"Failed to create Gemini model '{}'. The agent will use the default LLM. Error: {}",
912-
config.model(),
913-
e.getMessage());
914-
}
915-
} else {
916-
logger.warn(
917-
"Model '{}' is not a supported Gemini model. The agent will use the default LLM.",
918-
config.model());
919-
}
891+
builder.model(config.model());
920892
}
921893

922894
// Set optional transfer configuration

core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,4 +178,69 @@ public void fromConfig_missingRequiredFields_throwsException() throws IOExceptio
178178
assertThat(exception).hasMessageThat().contains("Failed to create agent from config");
179179
assertThat(exception.getCause()).isNotNull();
180180
}
181+
182+
@Test
183+
public void fromConfig_withModel_setsModelOnAgent() throws IOException, ConfigurationException {
184+
File configFile = tempFolder.newFile("with_model.yaml");
185+
Files.writeString(
186+
configFile.toPath(),
187+
"name: modelAgent\n"
188+
+ "description: Agent with a model\n"
189+
+ "instruction: test instruction\n"
190+
+ "agent_class: LlmAgent\n"
191+
+ "model: \"gemini-pro\"\n");
192+
String configPath = configFile.getAbsolutePath();
193+
194+
BaseAgent agent = ConfigAgentUtils.fromConfig(configPath);
195+
196+
assertThat(agent).isInstanceOf(LlmAgent.class);
197+
LlmAgent llmAgent = (LlmAgent) agent;
198+
assertThat(llmAgent.model()).isPresent();
199+
assertThat(llmAgent.model().get().modelName()).hasValue("gemini-pro");
200+
}
201+
202+
@Test
203+
public void fromConfig_withEmptyModel_doesNotSetModelOnAgent()
204+
throws IOException, ConfigurationException {
205+
File configFile = tempFolder.newFile("empty_model.yaml");
206+
Files.writeString(
207+
configFile.toPath(),
208+
"name: emptyModelAgent\n"
209+
+ "description: Agent with an empty model\n"
210+
+ "instruction: test instruction\n"
211+
+ "agent_class: LlmAgent\n"
212+
+ "model: \"\"\n");
213+
String configPath = configFile.getAbsolutePath();
214+
215+
BaseAgent agent = ConfigAgentUtils.fromConfig(configPath);
216+
217+
assertThat(agent).isInstanceOf(LlmAgent.class);
218+
LlmAgent llmAgent = (LlmAgent) agent;
219+
assertThat(llmAgent.model()).isEmpty();
220+
}
221+
222+
@Test
223+
public void fromConfig_withInvalidModel_throwsExceptionOnModelResolution()
224+
throws IOException, ConfigurationException {
225+
File configFile = tempFolder.newFile("invalid_model.yaml");
226+
Files.writeString(
227+
configFile.toPath(),
228+
"""
229+
name: invalidModelAgent
230+
description: Agent with an invalid model
231+
instruction: test instruction
232+
agent_class: LlmAgent
233+
model: "invalid-model-name"
234+
""");
235+
String configPath = configFile.getAbsolutePath();
236+
237+
BaseAgent agent = ConfigAgentUtils.fromConfig(configPath);
238+
239+
assertThat(agent).isInstanceOf(LlmAgent.class);
240+
LlmAgent llmAgent = (LlmAgent) agent;
241+
242+
IllegalArgumentException exception =
243+
assertThrows(IllegalArgumentException.class, llmAgent::resolvedModel);
244+
assertThat(exception).hasMessageThat().contains("invalid-model-name");
245+
}
181246
}

core/src/test/java/com/google/adk/agents/LlmAgentTest.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
import static org.junit.Assert.assertThrows;
2828

2929
import com.google.adk.events.Event;
30+
import com.google.adk.models.LlmRegistry;
3031
import com.google.adk.models.LlmResponse;
32+
import com.google.adk.models.Model;
3133
import com.google.adk.testing.TestLlm;
3234
import com.google.adk.testing.TestUtils.EchoTool;
3335
import com.google.adk.tools.BaseTool;
@@ -289,4 +291,16 @@ public void testCanonicalGlobalInstruction_providerInstructionInjectsContext() {
289291

290292
assertThat(canonicalInstruction).isEqualTo(instruction + invocationContext.invocationId());
291293
}
294+
295+
@Test
296+
public void resolveModel_withModelName_resolvesFromRegistry() {
297+
String modelName = "test-model";
298+
TestLlm testLlm = createTestLlm(LlmResponse.builder().build());
299+
LlmRegistry.registerLlm(modelName, (name) -> testLlm);
300+
LlmAgent agent = createTestAgentBuilder(testLlm).model(modelName).build();
301+
Model resolvedModel = agent.resolvedModel();
302+
303+
assertThat(resolvedModel.modelName()).hasValue(modelName);
304+
assertThat(resolvedModel.model()).hasValue(testLlm);
305+
}
292306
}

0 commit comments

Comments
 (0)