Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/src/main/java/com/google/adk/tools/Annotations.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ public final class Annotations {
String name() default "";

String description() default "";

boolean optional() default false;
}

private Annotations() {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ public static FunctionDeclaration buildFunctionDeclaration(
if (ignoreParams.contains(paramName)) {
continue;
}
required.add(paramName);
Annotations.Schema schema = param.getAnnotation(Annotations.Schema.class);
if (schema == null || !schema.optional()) {
required.add(paramName);
}
properties.put(paramName, buildSchemaFromParameter(param));
}
builder.parameters(
Expand Down
28 changes: 20 additions & 8 deletions core/src/main/java/com/google/adk/tools/FunctionTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,17 @@ private Maybe<Map<String, Object>> call(Map<String, Object> args, ToolContext to
arguments[i] = null;
continue;
}
Annotations.Schema schema = parameters[i].getAnnotation(Annotations.Schema.class);
if (!args.containsKey(paramName)) {
throw new IllegalArgumentException(
String.format(
"The parameter '%s' was not found in the arguments provided by the model.",
paramName));
if (schema != null && schema.optional()) {
arguments[i] = null;
continue;
} else {
throw new IllegalArgumentException(
String.format(
"The parameter '%s' was not found in the arguments provided by the model.",
paramName));
}
}
Class<?> paramType = parameters[i].getType();
Object argValue = args.get(paramName);
Expand Down Expand Up @@ -274,11 +280,17 @@ public Flowable<Map<String, Object>> callLive(
}
continue;
}
Annotations.Schema schema = parameters[i].getAnnotation(Annotations.Schema.class);
if (!args.containsKey(paramName)) {
throw new IllegalArgumentException(
String.format(
"The parameter '%s' was not found in the arguments provided by the model.",
paramName));
if (schema != null && schema.optional()) {
arguments[i] = null;
continue;
} else {
throw new IllegalArgumentException(
String.format(
"The parameter '%s' was not found in the arguments provided by the model.",
paramName));
}
}
Class<?> paramType = parameters[i].getType();
Object argValue = args.get(paramName);
Expand Down
63 changes: 63 additions & 0 deletions core/src/test/java/com/google/adk/tools/FunctionToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,50 @@ public void create_withRecursiveParam_avoidsInfiniteRecursion() {
assertThat(tool.declaration().get().parameters()).hasValue(expectedParameters);
}

@Test
public void create_withOptionalParameter_excludesFromRequired() {
FunctionTool tool = FunctionTool.create(Functions.class, "functionWithOptionalParam");

assertThat(tool).isNotNull();
assertThat(tool.declaration().get().parameters())
.hasValue(
Schema.builder()
.type("OBJECT")
.properties(
ImmutableMap.of(
"requiredParam",
Schema.builder().type("STRING").description("A required parameter").build(),
"optionalParam",
Schema.builder()
.type("INTEGER")
.description("An optional parameter")
.build()))
.required(ImmutableList.of("requiredParam"))
.build());
}

@Test
public void call_withOptionalParameter_missingValue() throws Exception {
FunctionTool tool = FunctionTool.create(Functions.class, "functionWithOptionalParam");

Map<String, Object> result =
tool.runAsync(ImmutableMap.of("requiredParam", "test"), null).blockingGet();

assertThat(result)
.containsExactly(
"requiredParam", "test", "optionalParam", "null_value", "wasOptionalProvided", false);
}

@Test
public void call_withOptionalParameter_missingRequired_returnsError() {
FunctionTool tool = FunctionTool.create(Functions.class, "functionWithOptionalParam");

Map<String, Object> result =
tool.runAsync(ImmutableMap.of("optionalParam", "test"), null).blockingGet();

assertThat(result).containsExactly("status", "error", "message", "An internal error occurred.");
}

@Test
public void create_withMaybeMapReturnType() {
FunctionTool tool = FunctionTool.create(Functions.class, "returnsMaybeMap");
Expand Down Expand Up @@ -719,6 +763,25 @@ public static ImmutableMap<String, Object> recursiveParam(Node param) {
return ImmutableMap.of("param", param);
}

public static ImmutableMap<String, Object> functionWithOptionalParam(
@Annotations.Schema(name = "requiredParam", description = "A required parameter")
String requiredParam,
@Annotations.Schema(
name = "optionalParam",
description = "An optional parameter",
optional = true)
Integer optionalParam) {
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder.put("requiredParam", requiredParam);
if (optionalParam != null) {
builder.put("optionalParam", optionalParam);
} else {
builder.put("optionalParam", "null_value");
}
builder.put("wasOptionalProvided", optionalParam != null);
return builder.buildOrThrow();
}

public ImmutableMap<String, Object> nonStaticReturnAllSupportedParametersAsMap(
String stringParam,
boolean primitiveBoolParam,
Expand Down