diff --git a/core/src/main/java/com/google/adk/tools/Annotations.java b/core/src/main/java/com/google/adk/tools/Annotations.java index 6aad15ea..d2f3fbf2 100644 --- a/core/src/main/java/com/google/adk/tools/Annotations.java +++ b/core/src/main/java/com/google/adk/tools/Annotations.java @@ -33,6 +33,8 @@ public final class Annotations { String name() default ""; String description() default ""; + + boolean optional() default false; } private Annotations() {} diff --git a/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java b/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java index fa26530c..dec33afb 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java +++ b/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java @@ -104,7 +104,10 @@ public static FunctionDeclaration buildFunctionDeclaration( if (ignoreParams.contains(paramName)) { continue; } - required.add(paramName); + Annotations.Schema schema = param.getAnnotation(Annotations.Schema.class); + if (schema == null || !schema.optional()) { + required.add(paramName); + } properties.put(paramName, buildSchemaFromParameter(param)); } builder.parameters( diff --git a/core/src/main/java/com/google/adk/tools/FunctionTool.java b/core/src/main/java/com/google/adk/tools/FunctionTool.java index 73d22f8d..9ba28737 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionTool.java +++ b/core/src/main/java/com/google/adk/tools/FunctionTool.java @@ -197,11 +197,17 @@ private Maybe> call(Map args, ToolContext to arguments[i] = null; continue; } + Annotations.Schema schema = parameters[i].getAnnotation(Annotations.Schema.class); if (!args.containsKey(paramName)) { - throw new IllegalArgumentException( - String.format( - "The parameter '%s' was not found in the arguments provided by the model.", - paramName)); + if (schema != null && schema.optional()) { + arguments[i] = null; + continue; + } else { + throw new IllegalArgumentException( + String.format( + "The parameter '%s' was not found in the arguments provided by the model.", + paramName)); + } } Class paramType = parameters[i].getType(); Object argValue = args.get(paramName); @@ -274,11 +280,17 @@ public Flowable> callLive( } continue; } + Annotations.Schema schema = parameters[i].getAnnotation(Annotations.Schema.class); if (!args.containsKey(paramName)) { - throw new IllegalArgumentException( - String.format( - "The parameter '%s' was not found in the arguments provided by the model.", - paramName)); + if (schema != null && schema.optional()) { + arguments[i] = null; + continue; + } else { + throw new IllegalArgumentException( + String.format( + "The parameter '%s' was not found in the arguments provided by the model.", + paramName)); + } } Class paramType = parameters[i].getType(); Object argValue = args.get(paramName); diff --git a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java index af9d7bf4..6576431a 100644 --- a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java @@ -472,6 +472,50 @@ public void create_withRecursiveParam_avoidsInfiniteRecursion() { assertThat(tool.declaration().get().parameters()).hasValue(expectedParameters); } + @Test + public void create_withOptionalParameter_excludesFromRequired() { + FunctionTool tool = FunctionTool.create(Functions.class, "functionWithOptionalParam"); + + assertThat(tool).isNotNull(); + assertThat(tool.declaration().get().parameters()) + .hasValue( + Schema.builder() + .type("OBJECT") + .properties( + ImmutableMap.of( + "requiredParam", + Schema.builder().type("STRING").description("A required parameter").build(), + "optionalParam", + Schema.builder() + .type("INTEGER") + .description("An optional parameter") + .build())) + .required(ImmutableList.of("requiredParam")) + .build()); + } + + @Test + public void call_withOptionalParameter_missingValue() throws Exception { + FunctionTool tool = FunctionTool.create(Functions.class, "functionWithOptionalParam"); + + Map result = + tool.runAsync(ImmutableMap.of("requiredParam", "test"), null).blockingGet(); + + assertThat(result) + .containsExactly( + "requiredParam", "test", "optionalParam", "null_value", "wasOptionalProvided", false); + } + + @Test + public void call_withOptionalParameter_missingRequired_returnsError() { + FunctionTool tool = FunctionTool.create(Functions.class, "functionWithOptionalParam"); + + Map result = + tool.runAsync(ImmutableMap.of("optionalParam", "test"), null).blockingGet(); + + assertThat(result).containsExactly("status", "error", "message", "An internal error occurred."); + } + @Test public void create_withMaybeMapReturnType() { FunctionTool tool = FunctionTool.create(Functions.class, "returnsMaybeMap"); @@ -719,6 +763,25 @@ public static ImmutableMap recursiveParam(Node param) { return ImmutableMap.of("param", param); } + public static ImmutableMap functionWithOptionalParam( + @Annotations.Schema(name = "requiredParam", description = "A required parameter") + String requiredParam, + @Annotations.Schema( + name = "optionalParam", + description = "An optional parameter", + optional = true) + Integer optionalParam) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + builder.put("requiredParam", requiredParam); + if (optionalParam != null) { + builder.put("optionalParam", optionalParam); + } else { + builder.put("optionalParam", "null_value"); + } + builder.put("wasOptionalProvided", optionalParam != null); + return builder.buildOrThrow(); + } + public ImmutableMap nonStaticReturnAllSupportedParametersAsMap( String stringParam, boolean primitiveBoolParam,