|
30 | 30 | import java.lang.reflect.ParameterizedType; |
31 | 31 | import java.lang.reflect.Type; |
32 | 32 | import java.util.ArrayList; |
| 33 | +import java.util.HashSet; |
33 | 34 | import java.util.LinkedHashMap; |
34 | 35 | import java.util.List; |
35 | 36 | import java.util.Map; |
| 37 | +import java.util.Optional; |
| 38 | +import java.util.Set; |
| 39 | +import javax.annotation.Nullable; |
| 40 | +import org.slf4j.Logger; |
| 41 | +import org.slf4j.LoggerFactory; |
36 | 42 |
|
37 | 43 | /** Utility class for function calling. */ |
38 | 44 | public final class FunctionCallingUtils { |
39 | 45 |
|
40 | 46 | private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); |
| 47 | + private static final Logger logger = LoggerFactory.getLogger(FunctionCallingUtils.class); |
41 | 48 |
|
| 49 | + /** Holds the state during a single schema generation process to handle caching and recursion. */ |
| 50 | + private static class SchemaGenerationContext { |
| 51 | + private final Map<String, Schema> definitions = new LinkedHashMap<>(); |
| 52 | + private final Set<Type> processingStack = new HashSet<>(); |
| 53 | + |
| 54 | + boolean isProcessing(Type type) { |
| 55 | + return processingStack.contains(type); |
| 56 | + } |
| 57 | + |
| 58 | + void startProcessing(Type type) { |
| 59 | + processingStack.add(type); |
| 60 | + } |
| 61 | + |
| 62 | + void finishProcessing(Type type) { |
| 63 | + processingStack.remove(type); |
| 64 | + } |
| 65 | + |
| 66 | + Optional<Schema> getDefinition(String name) { |
| 67 | + return Optional.ofNullable(definitions.get(name)); |
| 68 | + } |
| 69 | + |
| 70 | + void addDefinition(String name, Schema schema) { |
| 71 | + definitions.put(name, schema); |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + /** |
| 76 | + * Builds a FunctionDeclaration from a Java Method, ignoring parameters with the given names. |
| 77 | + * |
| 78 | + * @param func The Java {@link Method} to convert into a FunctionDeclaration. |
| 79 | + * @param ignoreParams The names of parameters to ignore. |
| 80 | + * @return The generated {@link FunctionDeclaration}. |
| 81 | + * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. |
| 82 | + */ |
42 | 83 | public static FunctionDeclaration buildFunctionDeclaration( |
43 | 84 | Method func, List<String> ignoreParams) { |
44 | 85 | String name = |
@@ -106,42 +147,137 @@ private static Schema buildSchemaFromParameter(Parameter param) { |
106 | 147 | return schema; |
107 | 148 | } |
108 | 149 |
|
| 150 | + /** |
| 151 | + * Builds a Schema from a Java Type, creating a new context for the generation process. |
| 152 | + * |
| 153 | + * @param type The Java {@link Type} to convert into a Schema. |
| 154 | + * @return The generated {@link Schema}. |
| 155 | + * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. |
| 156 | + */ |
109 | 157 | public static Schema buildSchemaFromType(Type type) { |
110 | | - Schema.Builder builder = Schema.builder(); |
111 | | - if (type instanceof ParameterizedType parameterizedType) { |
112 | | - String rawTypeName = ((Class<?>) parameterizedType.getRawType()).getName(); |
113 | | - switch (rawTypeName) { |
114 | | - case "java.util.List", "com.google.common.collect.ImmutableList": |
115 | | - Schema itemSchema = buildSchemaFromType(parameterizedType.getActualTypeArguments()[0]); |
| 158 | + return buildSchemaRecursive(type, new SchemaGenerationContext()); |
| 159 | + } |
| 160 | + |
| 161 | + /** |
| 162 | + * Recursively builds a Schema from a Java Type using a context to manage recursion and caching. |
| 163 | + * |
| 164 | + * @param type The Java {@link Type} to convert. |
| 165 | + * @param context The {@link SchemaGenerationContext} for this generation task. |
| 166 | + * @return The generated {@link Schema}. |
| 167 | + * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. |
| 168 | + */ |
| 169 | + @SuppressWarnings("deprecation") // We don't have actual instances of the type |
| 170 | + private static Schema buildSchemaRecursive(Type type, SchemaGenerationContext context) { |
| 171 | + String definitionName = getTypeKey(type); |
| 172 | + |
| 173 | + if (definitionName != null) { |
| 174 | + if (context.isProcessing(type)) { |
| 175 | + logger.warn("Type {} is recursive. Omitting from schema.", type); |
| 176 | + return Schema.builder() |
| 177 | + .type("OBJECT") |
| 178 | + .description("Recursive reference to " + definitionName + " omitted.") |
| 179 | + .build(); |
| 180 | + } |
| 181 | + Optional<Schema> cachedSchema = context.getDefinition(definitionName); |
| 182 | + if (cachedSchema.isPresent()) { |
| 183 | + return cachedSchema.get(); |
| 184 | + } |
| 185 | + } |
| 186 | + |
| 187 | + context.startProcessing(type); |
| 188 | + |
| 189 | + Schema resultSchema; |
| 190 | + try { |
| 191 | + Schema.Builder builder = Schema.builder(); |
| 192 | + if (type instanceof ParameterizedType parameterizedType) { |
| 193 | + Class<?> rawClass = (Class<?>) parameterizedType.getRawType(); |
| 194 | + if (List.class.isAssignableFrom(rawClass)) { |
| 195 | + Schema itemSchema = |
| 196 | + buildSchemaRecursive(parameterizedType.getActualTypeArguments()[0], context); |
116 | 197 | builder.type("ARRAY").items(itemSchema); |
117 | | - break; |
118 | | - case "java.util.Map": |
119 | | - case "com.google.common.collect.ImmutableMap": |
| 198 | + } else if (Map.class.isAssignableFrom(rawClass)) { |
120 | 199 | builder.type("OBJECT"); |
121 | | - break; |
122 | | - default: |
123 | | - throw new IllegalArgumentException("Unsupported generic type: " + type); |
124 | | - } |
125 | | - } else if (type instanceof Class<?> clazz) { |
126 | | - switch (clazz.getName()) { |
127 | | - case "java.lang.String" -> builder.type("STRING"); |
128 | | - case "boolean", "java.lang.Boolean" -> builder.type("BOOLEAN"); |
129 | | - case "int", "java.lang.Integer" -> builder.type("INTEGER"); |
130 | | - case "double", "java.lang.Double", "float", "java.lang.Float", "long", "java.lang.Long" -> |
131 | | - builder.type("NUMBER"); |
132 | | - case "java.util.Map", "com.google.common.collect.ImmutableMap" -> builder.type("OBJECT"); |
133 | | - default -> { |
| 200 | + } else { |
| 201 | + // Fallback for other parameterized types (e.g., custom generics) is to inspect the |
| 202 | + // raw type. |
| 203 | + return buildSchemaRecursive(rawClass, context); |
| 204 | + } |
| 205 | + } else if (type instanceof Class<?> clazz) { |
| 206 | + if (clazz.isEnum()) { |
| 207 | + builder.type("STRING"); |
| 208 | + List<String> enumValues = new ArrayList<>(); |
| 209 | + for (Object enumConstant : clazz.getEnumConstants()) { |
| 210 | + enumValues.add(enumConstant.toString()); |
| 211 | + } |
| 212 | + builder.enum_(enumValues); |
| 213 | + } else if (String.class.equals(clazz)) { |
| 214 | + builder.type("STRING"); |
| 215 | + } else if (Boolean.class.equals(clazz) || boolean.class.equals(clazz)) { |
| 216 | + builder.type("BOOLEAN"); |
| 217 | + } else if (Integer.class.equals(clazz) || int.class.equals(clazz)) { |
| 218 | + builder.type("INTEGER"); |
| 219 | + } else if (Double.class.equals(clazz) |
| 220 | + || double.class.equals(clazz) |
| 221 | + || Float.class.equals(clazz) |
| 222 | + || float.class.equals(clazz) |
| 223 | + || Long.class.equals(clazz) |
| 224 | + || long.class.equals(clazz)) { |
| 225 | + builder.type("NUMBER"); |
| 226 | + } else if (Map.class.isAssignableFrom(clazz)) { |
| 227 | + builder.type("OBJECT"); |
| 228 | + } else { |
| 229 | + // Default to treating as a POJO. |
| 230 | + if (!OBJECT_MAPPER.canSerialize(clazz)) { |
| 231 | + throw new IllegalArgumentException( |
| 232 | + "Unsupported type: " |
| 233 | + + clazz.getName() |
| 234 | + + ". The type must be a Jackson-serializable POJO or a registered" |
| 235 | + + " primitive. Opaque types like Protobuf models are not supported" |
| 236 | + + " directly."); |
| 237 | + } |
134 | 238 | BeanDescription beanDescription = |
135 | 239 | OBJECT_MAPPER.getSerializationConfig().introspect(OBJECT_MAPPER.constructType(type)); |
136 | 240 | Map<String, Schema> properties = new LinkedHashMap<>(); |
137 | 241 | for (BeanPropertyDefinition property : beanDescription.findProperties()) { |
138 | | - properties.put(property.getName(), buildSchemaFromType(property.getRawPrimaryType())); |
| 242 | + Type propertyType = property.getRawPrimaryType(); |
| 243 | + if (propertyType == null) { |
| 244 | + continue; |
| 245 | + } |
| 246 | + properties.put(property.getName(), buildSchemaRecursive(propertyType, context)); |
139 | 247 | } |
140 | 248 | builder.type("OBJECT").properties(properties); |
141 | 249 | } |
142 | 250 | } |
| 251 | + resultSchema = builder.build(); |
| 252 | + } finally { |
| 253 | + context.finishProcessing(type); |
143 | 254 | } |
144 | | - return builder.build(); |
| 255 | + |
| 256 | + if (definitionName != null) { |
| 257 | + context.addDefinition(definitionName, resultSchema); |
| 258 | + } |
| 259 | + return resultSchema; |
| 260 | + } |
| 261 | + |
| 262 | + /** |
| 263 | + * Gets a stable, canonical name for a type to use as a key for caching and recursion tracking. |
| 264 | + * |
| 265 | + * @param type The type to name. |
| 266 | + * @return The canonical name of the type, or null if the type should not be tracked (e.g., |
| 267 | + * primitives). |
| 268 | + */ |
| 269 | + @Nullable |
| 270 | + private static String getTypeKey(Type type) { |
| 271 | + if (type instanceof Class<?> clazz) { |
| 272 | + if (clazz.isPrimitive() || clazz.isEnum() || clazz.getName().startsWith("java.")) { |
| 273 | + return null; |
| 274 | + } |
| 275 | + return clazz.getCanonicalName(); |
| 276 | + } |
| 277 | + if (type instanceof ParameterizedType pType) { |
| 278 | + return getTypeKey(pType.getRawType()); |
| 279 | + } |
| 280 | + return null; |
145 | 281 | } |
146 | 282 |
|
147 | 283 | private FunctionCallingUtils() {} |
|
0 commit comments