Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
public class BaseAgentConfig {
private String name;
private String description = "";
// TODO: Add agentClassType enum to the config and handle different values from user
// input.e.g.LLM_AGENT, LlmAgent
private String agentClass = null;
private String agentClass;

@JsonProperty(value = "name", required = true)
public String name() {
Expand Down
54 changes: 13 additions & 41 deletions core/src/main/java/com/google/adk/agents/LlmAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package com.google.adk.agents;

import static com.google.common.base.Strings.isNullOrEmpty;
import static com.google.common.base.Strings.nullToEmpty;
import static java.util.stream.Collectors.joining;

Expand Down Expand Up @@ -47,7 +46,7 @@
import com.google.adk.flows.llmflows.BaseLlmFlow;
import com.google.adk.flows.llmflows.SingleFlow;
import com.google.adk.models.BaseLlm;
import com.google.adk.models.Gemini;
import com.google.adk.models.LlmRegistry;
import com.google.adk.models.Model;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.BaseToolset;
Expand Down Expand Up @@ -835,8 +834,17 @@ public Model resolvedModel() {
*/
private Model resolveModelInternal() {
if (this.model.isPresent()) {
if (this.model().isPresent()) {
return this.model.get();
Model currentModel = this.model.get();

if (currentModel.model().isPresent()) {
return currentModel;
}

if (currentModel.modelName().isPresent()) {
String modelName = currentModel.modelName().get();
BaseLlm resolvedLlm = LlmRegistry.getLlm(modelName);

return Model.builder().modelName(modelName).model(resolvedLlm).build();
}
}
BaseAgent current = this.parentAgent();
Expand Down Expand Up @@ -879,44 +887,8 @@ public static LlmAgent fromConfig(LlmAgentConfig config, String configAbsPath)
.description(nullToEmpty(config.description()))
.instruction(config.instruction());

// Set optional model configuration
if (config.model() != null && !config.model().trim().isEmpty()) {
logger.info("Configuring model: {}", config.model());

// TODO: resolve model name
if (config.model().startsWith("gemini")) {
try {
// Check for API key in system properties (for testing) or environment variables
String apiKey = System.getProperty("GOOGLE_API_KEY");
if (isNullOrEmpty(apiKey)) {
apiKey = System.getProperty("GEMINI_API_KEY");
}
if (isNullOrEmpty(apiKey)) {
apiKey = System.getenv("GOOGLE_API_KEY");
}
if (isNullOrEmpty(apiKey)) {
apiKey = System.getenv("GEMINI_API_KEY");
}

Gemini.Builder geminiBuilder = Gemini.builder().modelName(config.model());
if (apiKey != null && !apiKey.isEmpty()) {
geminiBuilder.apiKey(apiKey);
}

BaseLlm model = geminiBuilder.build();
builder.model(model);
logger.debug("Successfully configured Gemini model: {}", config.model());
} catch (RuntimeException e) {
logger.warn(
"Failed to create Gemini model '{}'. The agent will use the default LLM. Error: {}",
config.model(),
e.getMessage());
}
} else {
logger.warn(
"Model '{}' is not a supported Gemini model. The agent will use the default LLM.",
config.model());
}
builder.model(config.model());
}

// Set optional transfer configuration
Expand Down
65 changes: 65 additions & 0 deletions core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,69 @@ public void fromConfig_missingRequiredFields_throwsException() throws IOExceptio
assertThat(exception).hasMessageThat().contains("Failed to create agent from config");
assertThat(exception.getCause()).isNotNull();
}

@Test
public void fromConfig_withModel_setsModelOnAgent() throws IOException, ConfigurationException {
File configFile = tempFolder.newFile("with_model.yaml");
Files.writeString(
configFile.toPath(),
"name: modelAgent\n"
+ "description: Agent with a model\n"
+ "instruction: test instruction\n"
+ "agent_class: LlmAgent\n"
+ "model: \"gemini-pro\"\n");
String configPath = configFile.getAbsolutePath();

BaseAgent agent = ConfigAgentUtils.fromConfig(configPath);

assertThat(agent).isInstanceOf(LlmAgent.class);
LlmAgent llmAgent = (LlmAgent) agent;
assertThat(llmAgent.model()).isPresent();
assertThat(llmAgent.model().get().modelName()).hasValue("gemini-pro");
}

@Test
public void fromConfig_withEmptyModel_doesNotSetModelOnAgent()
throws IOException, ConfigurationException {
File configFile = tempFolder.newFile("empty_model.yaml");
Files.writeString(
configFile.toPath(),
"name: emptyModelAgent\n"
+ "description: Agent with an empty model\n"
+ "instruction: test instruction\n"
+ "agent_class: LlmAgent\n"
+ "model: \"\"\n");
String configPath = configFile.getAbsolutePath();

BaseAgent agent = ConfigAgentUtils.fromConfig(configPath);

assertThat(agent).isInstanceOf(LlmAgent.class);
LlmAgent llmAgent = (LlmAgent) agent;
assertThat(llmAgent.model()).isEmpty();
}

@Test
public void fromConfig_withInvalidModel_throwsExceptionOnModelResolution()
throws IOException, ConfigurationException {
File configFile = tempFolder.newFile("invalid_model.yaml");
Files.writeString(
configFile.toPath(),
"""
name: invalidModelAgent
description: Agent with an invalid model
instruction: test instruction
agent_class: LlmAgent
model: "invalid-model-name"
""");
String configPath = configFile.getAbsolutePath();

BaseAgent agent = ConfigAgentUtils.fromConfig(configPath);

assertThat(agent).isInstanceOf(LlmAgent.class);
LlmAgent llmAgent = (LlmAgent) agent;

IllegalArgumentException exception =
assertThrows(IllegalArgumentException.class, llmAgent::resolvedModel);
assertThat(exception).hasMessageThat().contains("invalid-model-name");
}
}
14 changes: 14 additions & 0 deletions core/src/test/java/com/google/adk/agents/LlmAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
import static org.junit.Assert.assertThrows;

import com.google.adk.events.Event;
import com.google.adk.models.LlmRegistry;
import com.google.adk.models.LlmResponse;
import com.google.adk.models.Model;
import com.google.adk.testing.TestLlm;
import com.google.adk.testing.TestUtils.EchoTool;
import com.google.adk.tools.BaseTool;
Expand Down Expand Up @@ -289,4 +291,16 @@ public void testCanonicalGlobalInstruction_providerInstructionInjectsContext() {

assertThat(canonicalInstruction).isEqualTo(instruction + invocationContext.invocationId());
}

@Test
public void resolveModel_withModelName_resolvesFromRegistry() {
String modelName = "test-model";
TestLlm testLlm = createTestLlm(LlmResponse.builder().build());
LlmRegistry.registerLlm(modelName, (name) -> testLlm);
LlmAgent agent = createTestAgentBuilder(testLlm).model(modelName).build();
Model resolvedModel = agent.resolvedModel();

assertThat(resolvedModel.modelName()).hasValue(modelName);
assertThat(resolvedModel.model()).hasValue(testLlm);
}
}