diff --git a/core/src/main/java/com/google/adk/tools/FunctionTool.java b/core/src/main/java/com/google/adk/tools/FunctionTool.java index 3037ef40..b5e35f43 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionTool.java +++ b/core/src/main/java/com/google/adk/tools/FunctionTool.java @@ -181,55 +181,7 @@ public Single> runAsync(Map args, ToolContex @SuppressWarnings("unchecked") // For tool parameter type casting. private Maybe> call(Map args, ToolContext toolContext) throws IllegalAccessException, InvocationTargetException { - Parameter[] parameters = func.getParameters(); - Object[] arguments = new Object[parameters.length]; - for (int i = 0; i < parameters.length; i++) { - String paramName = - parameters[i].isAnnotationPresent(Annotations.Schema.class) - && !parameters[i].getAnnotation(Annotations.Schema.class).name().isEmpty() - ? parameters[i].getAnnotation(Annotations.Schema.class).name() - : parameters[i].getName(); - if (paramName.equals("toolContext")) { - arguments[i] = toolContext; - continue; - } - if (paramName.equals("inputStream")) { - arguments[i] = null; - continue; - } - if (!args.containsKey(paramName)) { - throw new IllegalArgumentException( - String.format( - "The parameter '%s' was not found in the arguments provided by the model.", - paramName)); - } - Class paramType = parameters[i].getType(); - Object argValue = args.get(paramName); - if (paramType.equals(List.class)) { - if (argValue instanceof List) { - Type type = - ((ParameterizedType) parameters[i].getParameterizedType()) - .getActualTypeArguments()[0]; - Class typeArgClass; - if (type instanceof Class) { - // Case 1: The argument is a simple class like String, Integer, etc. - typeArgClass = (Class) type; - } else if (type instanceof ParameterizedType pType) { - // Case 2: The argument is another parameterized type like Map - typeArgClass = (Class) pType.getRawType(); // Get the raw class (e.g., Map) - } else { - throw new IllegalArgumentException( - String.format("Unsupported parameterized type %s for '%s'", type, paramName)); - } - arguments[i] = createList((List) argValue, typeArgClass); - continue; - } - } else if (argValue instanceof Map) { - arguments[i] = OBJECT_MAPPER.convertValue(argValue, paramType); - continue; - } - arguments[i] = castValue(argValue, paramType); - } + Object[] arguments = buildArguments(args, toolContext, null); Object result = func.invoke(instance, arguments); if (result == null) { return Maybe.empty(); @@ -253,6 +205,21 @@ private Maybe> call(Map args, ToolContext to public Flowable> callLive( Map args, ToolContext toolContext, InvocationContext invocationContext) throws IllegalAccessException, InvocationTargetException { + Object[] arguments = buildArguments(args, toolContext, invocationContext); + Object result = func.invoke(instance, arguments); + if (result instanceof Flowable) { + return (Flowable>) result; + } else { + throw new IllegalArgumentException( + "callLive was called but the underlying function does not return a Flowable."); + } + } + + @SuppressWarnings("unchecked") // For tool parameter type casting. + private Object[] buildArguments( + Map args, + ToolContext toolContext, + @Nullable InvocationContext invocationContext) { Parameter[] parameters = func.getParameters(); Object[] arguments = new Object[parameters.length]; for (int i = 0; i < parameters.length; i++) { @@ -266,7 +233,8 @@ public Flowable> callLive( continue; } if (paramName.equals("inputStream")) { - if (invocationContext.activeStreamingTools().containsKey(this.name()) + if (invocationContext != null + && invocationContext.activeStreamingTools().containsKey(this.name()) && invocationContext.activeStreamingTools().get(this.name()).stream() != null) { arguments[i] = invocationContext.activeStreamingTools().get(this.name()).stream(); } else { @@ -287,7 +255,8 @@ public Flowable> callLive( Type type = ((ParameterizedType) parameters[i].getParameterizedType()) .getActualTypeArguments()[0]; - arguments[i] = createList((List) argValue, (Class) type); + Class typeArgClass = getTypeClass(type, paramName); + arguments[i] = createList((List) argValue, typeArgClass); continue; } } else if (argValue instanceof Map) { @@ -296,12 +265,19 @@ public Flowable> callLive( } arguments[i] = castValue(argValue, paramType); } - Object result = func.invoke(instance, arguments); - if (result instanceof Flowable) { - return (Flowable>) result; + return arguments; + } + + private static Class getTypeClass(Type type, String paramName) { + if (type instanceof Class) { + // Case 1: The argument is a simple class like String, Integer, etc. + return (Class) type; + } else if (type instanceof ParameterizedType pType) { + // Case 2: The argument is another parameterized type like Map + return (Class) pType.getRawType(); // Get the raw class (e.g., Map) } else { - logger.warn("callLive was called but the underlying function does not return a Flowable."); - return Flowable.empty(); + throw new IllegalArgumentException( + String.format("Unsupported parameterized type %s for '%s'", type, paramName)); } }