Skip to content

Commit a406c8b

Browse files
Poggeccicopybara-github
authored andcommitted
feat: Enforce serializable types for FunctionTools
PiperOrigin-RevId: 791448986
1 parent e1214c1 commit a406c8b

File tree

2 files changed

+179
-30
lines changed

2 files changed

+179
-30
lines changed

core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java

Lines changed: 158 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,53 @@
3030
import java.lang.reflect.ParameterizedType;
3131
import java.lang.reflect.Type;
3232
import java.util.ArrayList;
33+
import java.util.HashSet;
3334
import java.util.LinkedHashMap;
3435
import java.util.List;
3536
import java.util.Map;
37+
import java.util.Optional;
38+
import java.util.Set;
39+
import javax.annotation.Nullable;
3640

3741
/** Utility class for function calling. */
3842
public final class FunctionCallingUtils {
3943

4044
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
4145

46+
/** Holds the state during a single schema generation process to handle caching and recursion. */
47+
private static class SchemaGenerationContext {
48+
private final Map<String, Schema> definitions = new LinkedHashMap<>();
49+
private final Set<Type> processingStack = new HashSet<>();
50+
51+
boolean isProcessing(Type type) {
52+
return processingStack.contains(type);
53+
}
54+
55+
void startProcessing(Type type) {
56+
processingStack.add(type);
57+
}
58+
59+
void finishProcessing(Type type) {
60+
processingStack.remove(type);
61+
}
62+
63+
Optional<Schema> getDefinition(String name) {
64+
return Optional.ofNullable(definitions.get(name));
65+
}
66+
67+
void addDefinition(String name, Schema schema) {
68+
definitions.put(name, schema);
69+
}
70+
}
71+
72+
/**
73+
* Builds a FunctionDeclaration from a Java Method, ignoring parameters with the given names.
74+
*
75+
* @param func The Java {@link Method} to convert into a FunctionDeclaration.
76+
* @param ignoreParams The names of parameters to ignore.
77+
* @return The generated {@link FunctionDeclaration}.
78+
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
79+
*/
4280
public static FunctionDeclaration buildFunctionDeclaration(
4381
Method func, List<String> ignoreParams) {
4482
String name =
@@ -106,42 +144,132 @@ private static Schema buildSchemaFromParameter(Parameter param) {
106144
return schema;
107145
}
108146

147+
/**
148+
* Builds a Schema from a Java Type, creating a new context for the generation process.
149+
*
150+
* @param type The Java {@link Type} to convert into a Schema.
151+
* @return The generated {@link Schema}.
152+
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
153+
*/
109154
public static Schema buildSchemaFromType(Type type) {
110-
Schema.Builder builder = Schema.builder();
111-
if (type instanceof ParameterizedType parameterizedType) {
112-
String rawTypeName = ((Class<?>) parameterizedType.getRawType()).getName();
113-
switch (rawTypeName) {
114-
case "java.util.List", "com.google.common.collect.ImmutableList":
115-
Schema itemSchema = buildSchemaFromType(parameterizedType.getActualTypeArguments()[0]);
116-
builder.type("ARRAY").items(itemSchema);
117-
break;
118-
case "java.util.Map":
119-
case "com.google.common.collect.ImmutableMap":
120-
builder.type("OBJECT");
121-
break;
122-
default:
123-
throw new IllegalArgumentException("Unsupported generic type: " + type);
155+
return buildSchemaRecursive(type, new SchemaGenerationContext());
156+
}
157+
158+
/**
159+
* Recursively builds a Schema from a Java Type using a context to manage recursion and caching.
160+
*
161+
* @param type The Java {@link Type} to convert.
162+
* @param context The {@link SchemaGenerationContext} for this generation task.
163+
* @return The generated {@link Schema}.
164+
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
165+
*/
166+
private static Schema buildSchemaRecursive(Type type, SchemaGenerationContext context) {
167+
String definitionName = getCanonicalName(type);
168+
169+
if (definitionName != null) {
170+
if (context.isProcessing(type)) {
171+
return Schema.builder()
172+
.type("OBJECT")
173+
.description("Recursive reference to " + definitionName + " omitted.")
174+
.build();
175+
}
176+
Optional<Schema> cachedSchema = context.getDefinition(definitionName);
177+
if (cachedSchema.isPresent()) {
178+
return cachedSchema.get();
124179
}
125-
} else if (type instanceof Class<?> clazz) {
126-
switch (clazz.getName()) {
127-
case "java.lang.String" -> builder.type("STRING");
128-
case "boolean", "java.lang.Boolean" -> builder.type("BOOLEAN");
129-
case "int", "java.lang.Integer" -> builder.type("INTEGER");
130-
case "double", "java.lang.Double", "float", "java.lang.Float", "long", "java.lang.Long" ->
131-
builder.type("NUMBER");
132-
case "java.util.Map", "com.google.common.collect.ImmutableMap" -> builder.type("OBJECT");
133-
default -> {
134-
BeanDescription beanDescription =
135-
OBJECT_MAPPER.getSerializationConfig().introspect(OBJECT_MAPPER.constructType(type));
136-
Map<String, Schema> properties = new LinkedHashMap<>();
137-
for (BeanPropertyDefinition property : beanDescription.findProperties()) {
138-
properties.put(property.getName(), buildSchemaFromType(property.getRawPrimaryType()));
180+
}
181+
182+
context.startProcessing(type);
183+
184+
Schema resultSchema;
185+
try {
186+
Schema.Builder builder = Schema.builder();
187+
if (type instanceof ParameterizedType parameterizedType) {
188+
String rawTypeName = ((Class<?>) parameterizedType.getRawType()).getName();
189+
switch (rawTypeName) {
190+
case "java.util.List", "com.google.common.collect.ImmutableList":
191+
Schema itemSchema =
192+
buildSchemaRecursive(parameterizedType.getActualTypeArguments()[0], context);
193+
builder.type("ARRAY").items(itemSchema);
194+
break;
195+
case "java.util.Map", "com.google.common.collect.ImmutableMap":
196+
builder.type("OBJECT");
197+
break;
198+
default:
199+
return buildSchemaRecursive(parameterizedType.getRawType(), context);
200+
}
201+
} else if (type instanceof Class<?> clazz) {
202+
if (clazz.isEnum()) {
203+
builder.type("STRING");
204+
} else {
205+
switch (clazz.getName()) {
206+
case "java.lang.String" -> builder.type("STRING");
207+
case "boolean", "java.lang.Boolean" -> builder.type("BOOLEAN");
208+
case "int", "java.lang.Integer" -> builder.type("INTEGER");
209+
case "double",
210+
"java.lang.Double",
211+
"float",
212+
"java.lang.Float",
213+
"long",
214+
"java.lang.Long" ->
215+
builder.type("NUMBER");
216+
case "java.util.Map", "com.google.common.collect.ImmutableMap" ->
217+
builder.type("OBJECT");
218+
default -> {
219+
if (!OBJECT_MAPPER.canSerialize(clazz)) {
220+
throw new IllegalArgumentException(
221+
"Unsupported type: "
222+
+ clazz.getName()
223+
+ ". The type must be a Jackson-serializable POJO or a registered"
224+
+ " primitive. Opaque types like Protobuf models are not supported"
225+
+ " directly.");
226+
}
227+
BeanDescription beanDescription =
228+
OBJECT_MAPPER
229+
.getSerializationConfig()
230+
.introspect(OBJECT_MAPPER.constructType(type));
231+
Map<String, Schema> properties = new LinkedHashMap<>();
232+
for (BeanPropertyDefinition property : beanDescription.findProperties()) {
233+
Type propertyType = property.getRawPrimaryType();
234+
if (propertyType == null) {
235+
continue;
236+
}
237+
properties.put(property.getName(), buildSchemaRecursive(propertyType, context));
238+
}
239+
builder.type("OBJECT").properties(properties);
240+
}
139241
}
140-
builder.type("OBJECT").properties(properties);
141242
}
142243
}
244+
resultSchema = builder.build();
245+
} finally {
246+
context.finishProcessing(type);
143247
}
144-
return builder.build();
248+
249+
if (definitionName != null) {
250+
context.addDefinition(definitionName, resultSchema);
251+
}
252+
return resultSchema;
253+
}
254+
255+
/**
256+
* Gets a stable, canonical name for a type to use as a key for caching and recursion tracking.
257+
*
258+
* @param type The type to name.
259+
* @return A simple string name, or null if the type should not be tracked (e.g., primitives).
260+
*/
261+
@Nullable
262+
private static String getCanonicalName(Type type) {
263+
if (type instanceof Class<?> clazz) {
264+
if (clazz.isPrimitive() || clazz.isEnum() || clazz.getName().startsWith("java.")) {
265+
return null;
266+
}
267+
return clazz.getSimpleName();
268+
}
269+
if (type instanceof ParameterizedType pType) {
270+
return getCanonicalName(pType.getRawType());
271+
}
272+
return null;
145273
}
146274

147275
private FunctionCallingUtils() {}

core/src/test/java/com/google/adk/tools/FunctionToolTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.google.common.collect.ImmutableMap;
2626
import com.google.genai.types.FunctionDeclaration;
2727
import com.google.genai.types.Schema;
28+
import com.google.protobuf.Timestamp;
2829
import io.reactivex.rxjava3.core.Maybe;
2930
import io.reactivex.rxjava3.core.Single;
3031
import java.lang.reflect.Method;
@@ -38,6 +39,13 @@
3839
/** Unit tests for {@link FunctionTool}. */
3940
@RunWith(JUnit4.class)
4041
public final class FunctionToolTest {
42+
43+
@Test
44+
public void create_withRecursiveParameter_raisesIllegalArgumentException() {
45+
assertThrows(
46+
IllegalArgumentException.class, () -> FunctionTool.create(Functions.class, "doThing"));
47+
}
48+
4149
@Test
4250
public void create_withStaticMethod_success() throws NoSuchMethodException {
4351
Method method = Functions.class.getMethod("voidReturnWithoutSchema");
@@ -429,6 +437,18 @@ public void call_nonStaticWithAllSupportedParameterTypes() throws Exception {
429437
}
430438

431439
static class Functions {
440+
441+
@Annotations.Schema(
442+
name = "doThing",
443+
description = "This function fetches stats that the agent needs.")
444+
public static Maybe<Map<String, Object>> doThing(
445+
@Annotations.Schema(
446+
name = "recursiveParam",
447+
description = "Protobuf fields have a recursive property in them.")
448+
Timestamp recursiveParam) {
449+
return Maybe.just(ImmutableMap.of("key", "value"));
450+
}
451+
432452
@Annotations.Schema(name = "my_function", description = "A test function")
433453
public static void voidReturnWithSchemaAndToolContext(
434454
@Annotations.Schema(name = "first_param", description = "An integer parameter") int param1,
@@ -582,4 +602,5 @@ public void setField2(int value) {
582602
privateField2 = value;
583603
}
584604
}
605+
585606
}

0 commit comments

Comments
 (0)