Skip to content

Commit f992b04

Browse files
Poggeccicopybara-github
authored andcommitted
feat: Enforce serializable types for FunctionTools
PiperOrigin-RevId: 791448986
1 parent e1214c1 commit f992b04

File tree

2 files changed

+294
-26
lines changed

2 files changed

+294
-26
lines changed

core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java

Lines changed: 160 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,56 @@
3030
import java.lang.reflect.ParameterizedType;
3131
import java.lang.reflect.Type;
3232
import java.util.ArrayList;
33+
import java.util.HashSet;
3334
import java.util.LinkedHashMap;
3435
import java.util.List;
3536
import java.util.Map;
37+
import java.util.Optional;
38+
import java.util.Set;
39+
import javax.annotation.Nullable;
40+
import org.slf4j.Logger;
41+
import org.slf4j.LoggerFactory;
3642

3743
/** Utility class for function calling. */
3844
public final class FunctionCallingUtils {
3945

4046
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
47+
private static final Logger logger = LoggerFactory.getLogger(FunctionCallingUtils.class);
4148

49+
/** Holds the state during a single schema generation process to handle caching and recursion. */
50+
private static class SchemaGenerationContext {
51+
private final Map<String, Schema> definitions = new LinkedHashMap<>();
52+
private final Set<Type> processingStack = new HashSet<>();
53+
54+
boolean isProcessing(Type type) {
55+
return processingStack.contains(type);
56+
}
57+
58+
void startProcessing(Type type) {
59+
processingStack.add(type);
60+
}
61+
62+
void finishProcessing(Type type) {
63+
processingStack.remove(type);
64+
}
65+
66+
Optional<Schema> getDefinition(String name) {
67+
return Optional.ofNullable(definitions.get(name));
68+
}
69+
70+
void addDefinition(String name, Schema schema) {
71+
definitions.put(name, schema);
72+
}
73+
}
74+
75+
/**
76+
* Builds a FunctionDeclaration from a Java Method, ignoring parameters with the given names.
77+
*
78+
* @param func The Java {@link Method} to convert into a FunctionDeclaration.
79+
* @param ignoreParams The names of parameters to ignore.
80+
* @return The generated {@link FunctionDeclaration}.
81+
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
82+
*/
4283
public static FunctionDeclaration buildFunctionDeclaration(
4384
Method func, List<String> ignoreParams) {
4485
String name =
@@ -106,42 +147,137 @@ private static Schema buildSchemaFromParameter(Parameter param) {
106147
return schema;
107148
}
108149

150+
/**
151+
* Builds a Schema from a Java Type, creating a new context for the generation process.
152+
*
153+
* @param type The Java {@link Type} to convert into a Schema.
154+
* @return The generated {@link Schema}.
155+
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
156+
*/
109157
public static Schema buildSchemaFromType(Type type) {
110-
Schema.Builder builder = Schema.builder();
111-
if (type instanceof ParameterizedType parameterizedType) {
112-
String rawTypeName = ((Class<?>) parameterizedType.getRawType()).getName();
113-
switch (rawTypeName) {
114-
case "java.util.List", "com.google.common.collect.ImmutableList":
115-
Schema itemSchema = buildSchemaFromType(parameterizedType.getActualTypeArguments()[0]);
158+
return buildSchemaRecursive(type, new SchemaGenerationContext());
159+
}
160+
161+
/**
162+
* Recursively builds a Schema from a Java Type using a context to manage recursion and caching.
163+
*
164+
* @param type The Java {@link Type} to convert.
165+
* @param context The {@link SchemaGenerationContext} for this generation task.
166+
* @return The generated {@link Schema}.
167+
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
168+
*/
169+
@SuppressWarnings("deprecation") // We don't have actual instances of the type
170+
private static Schema buildSchemaRecursive(Type type, SchemaGenerationContext context) {
171+
String definitionName = getTypeKey(type);
172+
173+
if (definitionName != null) {
174+
if (context.isProcessing(type)) {
175+
logger.warn("Type {} is recursive. Omitting from schema.", type);
176+
return Schema.builder()
177+
.type("OBJECT")
178+
.description("Recursive reference to " + definitionName + " omitted.")
179+
.build();
180+
}
181+
Optional<Schema> cachedSchema = context.getDefinition(definitionName);
182+
if (cachedSchema.isPresent()) {
183+
return cachedSchema.get();
184+
}
185+
}
186+
187+
context.startProcessing(type);
188+
189+
Schema resultSchema;
190+
try {
191+
Schema.Builder builder = Schema.builder();
192+
if (type instanceof ParameterizedType parameterizedType) {
193+
Class<?> rawClass = (Class<?>) parameterizedType.getRawType();
194+
if (List.class.isAssignableFrom(rawClass)) {
195+
Schema itemSchema =
196+
buildSchemaRecursive(parameterizedType.getActualTypeArguments()[0], context);
116197
builder.type("ARRAY").items(itemSchema);
117-
break;
118-
case "java.util.Map":
119-
case "com.google.common.collect.ImmutableMap":
198+
} else if (Map.class.isAssignableFrom(rawClass)) {
120199
builder.type("OBJECT");
121-
break;
122-
default:
123-
throw new IllegalArgumentException("Unsupported generic type: " + type);
124-
}
125-
} else if (type instanceof Class<?> clazz) {
126-
switch (clazz.getName()) {
127-
case "java.lang.String" -> builder.type("STRING");
128-
case "boolean", "java.lang.Boolean" -> builder.type("BOOLEAN");
129-
case "int", "java.lang.Integer" -> builder.type("INTEGER");
130-
case "double", "java.lang.Double", "float", "java.lang.Float", "long", "java.lang.Long" ->
131-
builder.type("NUMBER");
132-
case "java.util.Map", "com.google.common.collect.ImmutableMap" -> builder.type("OBJECT");
133-
default -> {
200+
} else {
201+
// Fallback for other parameterized types (e.g., custom generics) is to inspect the
202+
// raw type.
203+
return buildSchemaRecursive(rawClass, context);
204+
}
205+
} else if (type instanceof Class<?> clazz) {
206+
if (clazz.isEnum()) {
207+
builder.type("STRING");
208+
List<String> enumValues = new ArrayList<>();
209+
for (Object enumConstant : clazz.getEnumConstants()) {
210+
enumValues.add(enumConstant.toString());
211+
}
212+
builder.enum_(enumValues);
213+
} else if (String.class.equals(clazz)) {
214+
builder.type("STRING");
215+
} else if (Boolean.class.equals(clazz) || boolean.class.equals(clazz)) {
216+
builder.type("BOOLEAN");
217+
} else if (Integer.class.equals(clazz) || int.class.equals(clazz)) {
218+
builder.type("INTEGER");
219+
} else if (Double.class.equals(clazz)
220+
|| double.class.equals(clazz)
221+
|| Float.class.equals(clazz)
222+
|| float.class.equals(clazz)
223+
|| Long.class.equals(clazz)
224+
|| long.class.equals(clazz)) {
225+
builder.type("NUMBER");
226+
} else if (Map.class.isAssignableFrom(clazz)) {
227+
builder.type("OBJECT");
228+
} else {
229+
// Default to treating as a POJO.
230+
if (!OBJECT_MAPPER.canSerialize(clazz)) {
231+
throw new IllegalArgumentException(
232+
"Unsupported type: "
233+
+ clazz.getName()
234+
+ ". The type must be a Jackson-serializable POJO or a registered"
235+
+ " primitive. Opaque types like Protobuf models are not supported"
236+
+ " directly.");
237+
}
134238
BeanDescription beanDescription =
135239
OBJECT_MAPPER.getSerializationConfig().introspect(OBJECT_MAPPER.constructType(type));
136240
Map<String, Schema> properties = new LinkedHashMap<>();
137241
for (BeanPropertyDefinition property : beanDescription.findProperties()) {
138-
properties.put(property.getName(), buildSchemaFromType(property.getRawPrimaryType()));
242+
Type propertyType = property.getRawPrimaryType();
243+
if (propertyType == null) {
244+
continue;
245+
}
246+
properties.put(property.getName(), buildSchemaRecursive(propertyType, context));
139247
}
140248
builder.type("OBJECT").properties(properties);
141249
}
142250
}
251+
resultSchema = builder.build();
252+
} finally {
253+
context.finishProcessing(type);
143254
}
144-
return builder.build();
255+
256+
if (definitionName != null) {
257+
context.addDefinition(definitionName, resultSchema);
258+
}
259+
return resultSchema;
260+
}
261+
262+
/**
263+
* Gets a stable, canonical name for a type to use as a key for caching and recursion tracking.
264+
*
265+
* @param type The type to name.
266+
* @return The canonical name of the type, or null if the type should not be tracked (e.g.,
267+
* primitives).
268+
*/
269+
@Nullable
270+
private static String getTypeKey(Type type) {
271+
if (type instanceof Class<?> clazz) {
272+
if (clazz.isPrimitive() || clazz.isEnum() || clazz.getName().startsWith("java.")) {
273+
return null;
274+
}
275+
return clazz.getCanonicalName();
276+
}
277+
if (type instanceof ParameterizedType pType) {
278+
return getTypeKey(pType.getRawType());
279+
}
280+
return null;
145281
}
146282

147283
private FunctionCallingUtils() {}

0 commit comments

Comments
 (0)