Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 160 additions & 24 deletions core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,56 @@
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Utility class for function calling. */
public final class FunctionCallingUtils {

private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final Logger logger = LoggerFactory.getLogger(FunctionCallingUtils.class);

/** Holds the state during a single schema generation process to handle caching and recursion. */
private static class SchemaGenerationContext {
private final Map<String, Schema> definitions = new LinkedHashMap<>();
private final Set<Type> processingStack = new HashSet<>();

boolean isProcessing(Type type) {
return processingStack.contains(type);
}

void startProcessing(Type type) {
processingStack.add(type);
}

void finishProcessing(Type type) {
processingStack.remove(type);
}

Optional<Schema> getDefinition(String name) {
return Optional.ofNullable(definitions.get(name));
}

void addDefinition(String name, Schema schema) {
definitions.put(name, schema);
}
}

/**
* Builds a FunctionDeclaration from a Java Method, ignoring parameters with the given names.
*
* @param func The Java {@link Method} to convert into a FunctionDeclaration.
* @param ignoreParams The names of parameters to ignore.
* @return The generated {@link FunctionDeclaration}.
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
*/
public static FunctionDeclaration buildFunctionDeclaration(
Method func, List<String> ignoreParams) {
String name =
Expand Down Expand Up @@ -106,42 +147,137 @@ private static Schema buildSchemaFromParameter(Parameter param) {
return schema;
}

/**
* Builds a Schema from a Java Type, creating a new context for the generation process.
*
* @param type The Java {@link Type} to convert into a Schema.
* @return The generated {@link Schema}.
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
*/
public static Schema buildSchemaFromType(Type type) {
Schema.Builder builder = Schema.builder();
if (type instanceof ParameterizedType parameterizedType) {
String rawTypeName = ((Class<?>) parameterizedType.getRawType()).getName();
switch (rawTypeName) {
case "java.util.List", "com.google.common.collect.ImmutableList":
Schema itemSchema = buildSchemaFromType(parameterizedType.getActualTypeArguments()[0]);
return buildSchemaRecursive(type, new SchemaGenerationContext());
}

/**
* Recursively builds a Schema from a Java Type using a context to manage recursion and caching.
*
* @param type The Java {@link Type} to convert.
* @param context The {@link SchemaGenerationContext} for this generation task.
* @return The generated {@link Schema}.
* @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson.
*/
@SuppressWarnings("deprecation") // We don't have actual instances of the type
private static Schema buildSchemaRecursive(Type type, SchemaGenerationContext context) {
String definitionName = getTypeKey(type);

if (definitionName != null) {
if (context.isProcessing(type)) {
logger.warn("Type {} is recursive. Omitting from schema.", type);
return Schema.builder()
.type("OBJECT")
.description("Recursive reference to " + definitionName + " omitted.")
.build();
}
Optional<Schema> cachedSchema = context.getDefinition(definitionName);
if (cachedSchema.isPresent()) {
return cachedSchema.get();
}
}

context.startProcessing(type);

Schema resultSchema;
try {
Schema.Builder builder = Schema.builder();
if (type instanceof ParameterizedType parameterizedType) {
Class<?> rawClass = (Class<?>) parameterizedType.getRawType();
if (List.class.isAssignableFrom(rawClass)) {
Schema itemSchema =
buildSchemaRecursive(parameterizedType.getActualTypeArguments()[0], context);
builder.type("ARRAY").items(itemSchema);
break;
case "java.util.Map":
case "com.google.common.collect.ImmutableMap":
} else if (Map.class.isAssignableFrom(rawClass)) {
builder.type("OBJECT");
break;
default:
throw new IllegalArgumentException("Unsupported generic type: " + type);
}
} else if (type instanceof Class<?> clazz) {
switch (clazz.getName()) {
case "java.lang.String" -> builder.type("STRING");
case "boolean", "java.lang.Boolean" -> builder.type("BOOLEAN");
case "int", "java.lang.Integer" -> builder.type("INTEGER");
case "double", "java.lang.Double", "float", "java.lang.Float", "long", "java.lang.Long" ->
builder.type("NUMBER");
case "java.util.Map", "com.google.common.collect.ImmutableMap" -> builder.type("OBJECT");
default -> {
} else {
// Fallback for other parameterized types (e.g., custom generics) is to inspect the
// raw type.
return buildSchemaRecursive(rawClass, context);
}
} else if (type instanceof Class<?> clazz) {
if (clazz.isEnum()) {
builder.type("STRING");
List<String> enumValues = new ArrayList<>();
for (Object enumConstant : clazz.getEnumConstants()) {
enumValues.add(enumConstant.toString());
}
builder.enum_(enumValues);
} else if (String.class.equals(clazz)) {
builder.type("STRING");
} else if (Boolean.class.equals(clazz) || boolean.class.equals(clazz)) {
builder.type("BOOLEAN");
} else if (Integer.class.equals(clazz) || int.class.equals(clazz)) {
builder.type("INTEGER");
} else if (Double.class.equals(clazz)
|| double.class.equals(clazz)
|| Float.class.equals(clazz)
|| float.class.equals(clazz)
|| Long.class.equals(clazz)
|| long.class.equals(clazz)) {
builder.type("NUMBER");
} else if (Map.class.isAssignableFrom(clazz)) {
builder.type("OBJECT");
} else {
// Default to treating as a POJO.
if (!OBJECT_MAPPER.canSerialize(clazz)) {
throw new IllegalArgumentException(
"Unsupported type: "
+ clazz.getName()
+ ". The type must be a Jackson-serializable POJO or a registered"
+ " primitive. Opaque types like Protobuf models are not supported"
+ " directly.");
}
BeanDescription beanDescription =
OBJECT_MAPPER.getSerializationConfig().introspect(OBJECT_MAPPER.constructType(type));
Map<String, Schema> properties = new LinkedHashMap<>();
for (BeanPropertyDefinition property : beanDescription.findProperties()) {
properties.put(property.getName(), buildSchemaFromType(property.getRawPrimaryType()));
Type propertyType = property.getRawPrimaryType();
if (propertyType == null) {
continue;
}
properties.put(property.getName(), buildSchemaRecursive(propertyType, context));
}
builder.type("OBJECT").properties(properties);
}
}
resultSchema = builder.build();
} finally {
context.finishProcessing(type);
}
return builder.build();

if (definitionName != null) {
context.addDefinition(definitionName, resultSchema);
}
return resultSchema;
}

/**
* Gets a stable, canonical name for a type to use as a key for caching and recursion tracking.
*
* @param type The type to name.
* @return The canonical name of the type, or null if the type should not be tracked (e.g.,
* primitives).
*/
@Nullable
private static String getTypeKey(Type type) {
if (type instanceof Class<?> clazz) {
if (clazz.isPrimitive() || clazz.isEnum() || clazz.getName().startsWith("java.")) {
return null;
}
return clazz.getCanonicalName();
}
if (type instanceof ParameterizedType pType) {
return getTypeKey(pType.getRawType());
}
return null;
}

private FunctionCallingUtils() {}
Expand Down
Loading