|
30 | 30 | import java.lang.reflect.ParameterizedType; |
31 | 31 | import java.lang.reflect.Type; |
32 | 32 | import java.util.ArrayList; |
| 33 | +import java.util.HashSet; |
33 | 34 | import java.util.LinkedHashMap; |
34 | 35 | import java.util.List; |
35 | 36 | import java.util.Map; |
| 37 | +import java.util.Optional; |
| 38 | +import java.util.Set; |
| 39 | +import javax.annotation.Nullable; |
36 | 40 |
|
37 | 41 | /** Utility class for function calling. */ |
38 | 42 | public final class FunctionCallingUtils { |
39 | 43 |
|
40 | 44 | private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); |
41 | 45 |
|
| 46 | + /** Holds the state during a single schema generation process to handle caching and recursion. */ |
| 47 | + private static class SchemaGenerationContext { |
| 48 | + private final Map<String, Schema> definitions = new LinkedHashMap<>(); |
| 49 | + private final Set<Type> processingStack = new HashSet<>(); |
| 50 | + |
| 51 | + boolean isProcessing(Type type) { |
| 52 | + return processingStack.contains(type); |
| 53 | + } |
| 54 | + |
| 55 | + void startProcessing(Type type) { |
| 56 | + processingStack.add(type); |
| 57 | + } |
| 58 | + |
| 59 | + void finishProcessing(Type type) { |
| 60 | + processingStack.remove(type); |
| 61 | + } |
| 62 | + |
| 63 | + Optional<Schema> getDefinition(String name) { |
| 64 | + return Optional.ofNullable(definitions.get(name)); |
| 65 | + } |
| 66 | + |
| 67 | + void addDefinition(String name, Schema schema) { |
| 68 | + definitions.put(name, schema); |
| 69 | + } |
| 70 | + } |
| 71 | + |
| 72 | + /** |
| 73 | + * Builds a FunctionDeclaration from a Java Method, ignoring parameters with the given names. |
| 74 | + * |
| 75 | + * @param func The Java {@link Method} to convert into a FunctionDeclaration. |
| 76 | + * @param ignoreParams The names of parameters to ignore. |
| 77 | + * @return The generated {@link FunctionDeclaration}. |
| 78 | + * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. |
| 79 | + */ |
42 | 80 | public static FunctionDeclaration buildFunctionDeclaration( |
43 | 81 | Method func, List<String> ignoreParams) { |
44 | 82 | String name = |
@@ -106,42 +144,132 @@ private static Schema buildSchemaFromParameter(Parameter param) { |
106 | 144 | return schema; |
107 | 145 | } |
108 | 146 |
|
| 147 | + /** |
| 148 | + * Builds a Schema from a Java Type, creating a new context for the generation process. |
| 149 | + * |
| 150 | + * @param type The Java {@link Type} to convert into a Schema. |
| 151 | + * @return The generated {@link Schema}. |
| 152 | + * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. |
| 153 | + */ |
109 | 154 | public static Schema buildSchemaFromType(Type type) { |
110 | | - Schema.Builder builder = Schema.builder(); |
111 | | - if (type instanceof ParameterizedType parameterizedType) { |
112 | | - String rawTypeName = ((Class<?>) parameterizedType.getRawType()).getName(); |
113 | | - switch (rawTypeName) { |
114 | | - case "java.util.List", "com.google.common.collect.ImmutableList": |
115 | | - Schema itemSchema = buildSchemaFromType(parameterizedType.getActualTypeArguments()[0]); |
116 | | - builder.type("ARRAY").items(itemSchema); |
117 | | - break; |
118 | | - case "java.util.Map": |
119 | | - case "com.google.common.collect.ImmutableMap": |
120 | | - builder.type("OBJECT"); |
121 | | - break; |
122 | | - default: |
123 | | - throw new IllegalArgumentException("Unsupported generic type: " + type); |
| 155 | + return buildSchemaRecursive(type, new SchemaGenerationContext()); |
| 156 | + } |
| 157 | + |
| 158 | + /** |
| 159 | + * Recursively builds a Schema from a Java Type using a context to manage recursion and caching. |
| 160 | + * |
| 161 | + * @param type The Java {@link Type} to convert. |
| 162 | + * @param context The {@link SchemaGenerationContext} for this generation task. |
| 163 | + * @return The generated {@link Schema}. |
| 164 | + * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. |
| 165 | + */ |
| 166 | + private static Schema buildSchemaRecursive(Type type, SchemaGenerationContext context) { |
| 167 | + String definitionName = getCanonicalName(type); |
| 168 | + |
| 169 | + if (definitionName != null) { |
| 170 | + if (context.isProcessing(type)) { |
| 171 | + return Schema.builder() |
| 172 | + .type("OBJECT") |
| 173 | + .description("Recursive reference to " + definitionName + " omitted.") |
| 174 | + .build(); |
| 175 | + } |
| 176 | + Optional<Schema> cachedSchema = context.getDefinition(definitionName); |
| 177 | + if (cachedSchema.isPresent()) { |
| 178 | + return cachedSchema.get(); |
124 | 179 | } |
125 | | - } else if (type instanceof Class<?> clazz) { |
126 | | - switch (clazz.getName()) { |
127 | | - case "java.lang.String" -> builder.type("STRING"); |
128 | | - case "boolean", "java.lang.Boolean" -> builder.type("BOOLEAN"); |
129 | | - case "int", "java.lang.Integer" -> builder.type("INTEGER"); |
130 | | - case "double", "java.lang.Double", "float", "java.lang.Float", "long", "java.lang.Long" -> |
131 | | - builder.type("NUMBER"); |
132 | | - case "java.util.Map", "com.google.common.collect.ImmutableMap" -> builder.type("OBJECT"); |
133 | | - default -> { |
134 | | - BeanDescription beanDescription = |
135 | | - OBJECT_MAPPER.getSerializationConfig().introspect(OBJECT_MAPPER.constructType(type)); |
136 | | - Map<String, Schema> properties = new LinkedHashMap<>(); |
137 | | - for (BeanPropertyDefinition property : beanDescription.findProperties()) { |
138 | | - properties.put(property.getName(), buildSchemaFromType(property.getRawPrimaryType())); |
| 180 | + } |
| 181 | + |
| 182 | + context.startProcessing(type); |
| 183 | + |
| 184 | + Schema resultSchema; |
| 185 | + try { |
| 186 | + Schema.Builder builder = Schema.builder(); |
| 187 | + if (type instanceof ParameterizedType parameterizedType) { |
| 188 | + String rawTypeName = ((Class<?>) parameterizedType.getRawType()).getName(); |
| 189 | + switch (rawTypeName) { |
| 190 | + case "java.util.List", "com.google.common.collect.ImmutableList": |
| 191 | + Schema itemSchema = |
| 192 | + buildSchemaRecursive(parameterizedType.getActualTypeArguments()[0], context); |
| 193 | + builder.type("ARRAY").items(itemSchema); |
| 194 | + break; |
| 195 | + case "java.util.Map", "com.google.common.collect.ImmutableMap": |
| 196 | + builder.type("OBJECT"); |
| 197 | + break; |
| 198 | + default: |
| 199 | + return buildSchemaRecursive(parameterizedType.getRawType(), context); |
| 200 | + } |
| 201 | + } else if (type instanceof Class<?> clazz) { |
| 202 | + if (clazz.isEnum()) { |
| 203 | + builder.type("STRING"); |
| 204 | + } else { |
| 205 | + switch (clazz.getName()) { |
| 206 | + case "java.lang.String" -> builder.type("STRING"); |
| 207 | + case "boolean", "java.lang.Boolean" -> builder.type("BOOLEAN"); |
| 208 | + case "int", "java.lang.Integer" -> builder.type("INTEGER"); |
| 209 | + case "double", |
| 210 | + "java.lang.Double", |
| 211 | + "float", |
| 212 | + "java.lang.Float", |
| 213 | + "long", |
| 214 | + "java.lang.Long" -> |
| 215 | + builder.type("NUMBER"); |
| 216 | + case "java.util.Map", "com.google.common.collect.ImmutableMap" -> |
| 217 | + builder.type("OBJECT"); |
| 218 | + default -> { |
| 219 | + if (!OBJECT_MAPPER.canSerialize(clazz)) { |
| 220 | + throw new IllegalArgumentException( |
| 221 | + "Unsupported type: " |
| 222 | + + clazz.getName() |
| 223 | + + ". The type must be a Jackson-serializable POJO or a registered" |
| 224 | + + " primitive. Opaque types like Protobuf models are not supported" |
| 225 | + + " directly."); |
| 226 | + } |
| 227 | + BeanDescription beanDescription = |
| 228 | + OBJECT_MAPPER |
| 229 | + .getSerializationConfig() |
| 230 | + .introspect(OBJECT_MAPPER.constructType(type)); |
| 231 | + Map<String, Schema> properties = new LinkedHashMap<>(); |
| 232 | + for (BeanPropertyDefinition property : beanDescription.findProperties()) { |
| 233 | + Type propertyType = property.getRawPrimaryType(); |
| 234 | + if (propertyType == null) { |
| 235 | + continue; |
| 236 | + } |
| 237 | + properties.put(property.getName(), buildSchemaRecursive(propertyType, context)); |
| 238 | + } |
| 239 | + builder.type("OBJECT").properties(properties); |
| 240 | + } |
139 | 241 | } |
140 | | - builder.type("OBJECT").properties(properties); |
141 | 242 | } |
142 | 243 | } |
| 244 | + resultSchema = builder.build(); |
| 245 | + } finally { |
| 246 | + context.finishProcessing(type); |
143 | 247 | } |
144 | | - return builder.build(); |
| 248 | + |
| 249 | + if (definitionName != null) { |
| 250 | + context.addDefinition(definitionName, resultSchema); |
| 251 | + } |
| 252 | + return resultSchema; |
| 253 | + } |
| 254 | + |
| 255 | + /** |
| 256 | + * Gets a stable, canonical name for a type to use as a key for caching and recursion tracking. |
| 257 | + * |
| 258 | + * @param type The type to name. |
| 259 | + * @return A simple string name, or null if the type should not be tracked (e.g., primitives). |
| 260 | + */ |
| 261 | + @Nullable |
| 262 | + private static String getCanonicalName(Type type) { |
| 263 | + if (type instanceof Class<?> clazz) { |
| 264 | + if (clazz.isPrimitive() || clazz.isEnum() || clazz.getName().startsWith("java.")) { |
| 265 | + return null; |
| 266 | + } |
| 267 | + return clazz.getSimpleName(); |
| 268 | + } |
| 269 | + if (type instanceof ParameterizedType pType) { |
| 270 | + return getCanonicalName(pType.getRawType()); |
| 271 | + } |
| 272 | + return null; |
145 | 273 | } |
146 | 274 |
|
147 | 275 | private FunctionCallingUtils() {} |
|
0 commit comments