Skip to content

Commit 814ae54

Browse files
Poggeccicopybara-github
authored andcommitted
feat: Update FunctionTool to handle deserializing arbitrary return types
PiperOrigin-RevId: 788603173
1 parent a3746ed commit 814ae54

File tree

3 files changed

+77
-64
lines changed

3 files changed

+77
-64
lines changed

core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java

Lines changed: 24 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -70,26 +70,15 @@ public static FunctionDeclaration buildFunctionDeclaration(
7070

7171
Type returnType = func.getGenericReturnType();
7272
if (returnType != Void.TYPE) {
73-
Type realReturnType = returnType;
74-
if (returnType instanceof ParameterizedType) {
75-
ParameterizedType parameterizedReturnType = (ParameterizedType) returnType;
76-
String returnTypeName = ((Class<?>) parameterizedReturnType.getRawType()).getName();
77-
if (returnTypeName.equals("io.reactivex.rxjava3.core.Maybe")
78-
|| returnTypeName.equals("io.reactivex.rxjava3.core.Single")) {
79-
returnType = parameterizedReturnType.getActualTypeArguments()[0];
80-
if (returnType instanceof ParameterizedType) {
81-
ParameterizedType maybeParameterizedType = (ParameterizedType) returnType;
82-
returnTypeName = ((Class<?>) maybeParameterizedType.getRawType()).getName();
83-
}
84-
}
85-
if (returnTypeName.equals("java.util.Map")
86-
|| returnTypeName.equals("com.google.common.collect.ImmutableMap")) {
87-
return builder.response(buildSchemaFromType(returnType)).build();
73+
Type actualReturnType = returnType;
74+
if (returnType instanceof ParameterizedType parameterizedReturnType) {
75+
String rawTypeName = ((Class<?>) parameterizedReturnType.getRawType()).getName();
76+
if (rawTypeName.equals("io.reactivex.rxjava3.core.Maybe")
77+
|| rawTypeName.equals("io.reactivex.rxjava3.core.Single")) {
78+
actualReturnType = parameterizedReturnType.getActualTypeArguments()[0];
8879
}
8980
}
90-
throw new IllegalArgumentException(
91-
"Return type should be Map or Maybe<Map> or Single<Map>, but it was "
92-
+ realReturnType.getTypeName());
81+
builder.response(buildSchemaFromType(actualReturnType));
9382
}
9483
return builder.build();
9584
}
@@ -107,50 +96,31 @@ static FunctionDeclaration buildFunctionDeclaration(JsonBaseModel func, String d
10796
}
10897

10998
private static Schema buildSchemaFromParameter(Parameter param) {
110-
Schema.Builder builder = Schema.builder();
99+
Schema schema = buildSchemaFromType(param.getParameterizedType());
111100
if (param.isAnnotationPresent(Annotations.Schema.class)
112101
&& !param.getAnnotation(Annotations.Schema.class).description().isEmpty()) {
113-
builder.description(param.getAnnotation(Annotations.Schema.class).description());
114-
}
115-
switch (param.getType().getName()) {
116-
case "java.lang.String" -> builder.type("STRING");
117-
case "boolean", "java.lang.Boolean" -> builder.type("BOOLEAN");
118-
case "int", "java.lang.Integer" -> builder.type("INTEGER");
119-
case "double", "java.lang.Double", "float", "java.lang.Float", "long", "java.lang.Long" ->
120-
builder.type("NUMBER");
121-
case "java.util.List" ->
122-
builder
123-
.type("ARRAY")
124-
.items(
125-
buildSchemaFromType(
126-
((ParameterizedType) param.getParameterizedType())
127-
.getActualTypeArguments()[0]));
128-
case "java.util.Map" -> builder.type("OBJECT");
129-
default -> {
130-
BeanDescription beanDescription =
131-
OBJECT_MAPPER
132-
.getSerializationConfig()
133-
.introspect(OBJECT_MAPPER.constructType(param.getType()));
134-
Map<String, Schema> properties = new LinkedHashMap<>();
135-
for (BeanPropertyDefinition property : beanDescription.findProperties()) {
136-
properties.put(property.getName(), buildSchemaFromType(property.getRawPrimaryType()));
137-
}
138-
builder.type("OBJECT").properties(properties);
139-
}
102+
return schema.toBuilder()
103+
.description(param.getAnnotation(Annotations.Schema.class).description())
104+
.build();
140105
}
141-
return builder.build();
106+
return schema;
142107
}
143108

144109
public static Schema buildSchemaFromType(Type type) {
145110
Schema.Builder builder = Schema.builder();
146111
if (type instanceof ParameterizedType parameterizedType) {
147-
switch (((Class<?>) parameterizedType.getRawType()).getName()) {
148-
case "java.util.List" ->
149-
builder
150-
.type("ARRAY")
151-
.items(buildSchemaFromType(parameterizedType.getActualTypeArguments()[0]));
152-
case "java.util.Map", "com.google.common.collect.ImmutableMap" -> builder.type("OBJECT");
153-
default -> throw new IllegalArgumentException("Unsupported generic type: " + type);
112+
String rawTypeName = ((Class<?>) parameterizedType.getRawType()).getName();
113+
switch (rawTypeName) {
114+
case "java.util.List", "com.google.common.collect.ImmutableList":
115+
Schema itemSchema = buildSchemaFromType(parameterizedType.getActualTypeArguments()[0]);
116+
builder.type("ARRAY").items(itemSchema);
117+
break;
118+
case "java.util.Map":
119+
case "com.google.common.collect.ImmutableMap":
120+
builder.type("OBJECT");
121+
break;
122+
default:
123+
throw new IllegalArgumentException("Unsupported generic type: " + type);
154124
}
155125
} else if (type instanceof Class<?> clazz) {
156126
switch (clazz.getName()) {

core/src/main/java/com/google/adk/tools/FunctionTool.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.tools;
1818

19+
import com.fasterxml.jackson.core.type.TypeReference;
1920
import com.fasterxml.jackson.databind.ObjectMapper;
2021
import com.google.common.collect.ImmutableList;
2122
import com.google.common.collect.ImmutableMap;
@@ -196,11 +197,18 @@ private Maybe<Map<String, Object>> call(Map<String, Object> args, ToolContext to
196197
if (result == null) {
197198
return Maybe.empty();
198199
} else if (result instanceof Maybe) {
199-
return (Maybe<Map<String, Object>>) result;
200+
return ((Maybe<?>) result)
201+
.map(
202+
data ->
203+
OBJECT_MAPPER.convertValue(data, new TypeReference<Map<String, Object>>() {}));
200204
} else if (result instanceof Single) {
201-
return ((Single<Map<String, Object>>) result).toMaybe();
205+
return ((Single<?>) result)
206+
.map(
207+
data -> OBJECT_MAPPER.convertValue(data, new TypeReference<Map<String, Object>>() {}))
208+
.toMaybe();
202209
} else {
203-
return Maybe.just((Map<String, Object>) result);
210+
return Maybe.just(
211+
OBJECT_MAPPER.convertValue(result, new TypeReference<Map<String, Object>>() {}));
204212
}
205213
}
206214

core/src/test/java/com/google/adk/tools/FunctionToolTest.java

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -337,13 +337,6 @@ public void call_withMaybeMapReturnType() throws Exception {
337337
assertThat(result).containsExactly("key", "value");
338338
}
339339

340-
@Test
341-
public void create_withMaybeStringReturnType() {
342-
assertThrows(
343-
IllegalArgumentException.class,
344-
() -> FunctionTool.create(Functions.class, "returnsMaybeString"));
345-
}
346-
347340
@Test
348341
public void create_withSingleMapReturnType() {
349342
FunctionTool tool = FunctionTool.create(Functions.class, "returnsSingleMap");
@@ -362,6 +355,27 @@ public void call_withSingleMapReturnType() throws Exception {
362355
assertThat(result).containsExactly("key", "value");
363356
}
364357

358+
@Test
359+
public void call_withPojoReturnType() throws Exception {
360+
FunctionTool tool = FunctionTool.create(Functions.class, "returnsPojo");
361+
Map<String, Object> result = tool.runAsync(ImmutableMap.of(), null).blockingGet();
362+
assertThat(result).containsExactly("field1", "abc", "field2", 123);
363+
}
364+
365+
@Test
366+
public void call_withSinglePojoReturnType() throws Exception {
367+
FunctionTool tool = FunctionTool.create(Functions.class, "returnsSinglePojo");
368+
Map<String, Object> result = tool.runAsync(ImmutableMap.of(), null).blockingGet();
369+
assertThat(result).containsExactly("field1", "abc", "field2", 123);
370+
}
371+
372+
@Test
373+
public void call_withMaybePojoReturnType() throws Exception {
374+
FunctionTool tool = FunctionTool.create(Functions.class, "returnsMaybePojo");
375+
Map<String, Object> result = tool.runAsync(ImmutableMap.of(), null).blockingGet();
376+
assertThat(result).containsExactly("field1", "abc", "field2", 123);
377+
}
378+
365379
@Test
366380
public void call_nonStaticWithAllSupportedParameterTypes() throws Exception {
367381
Functions functions = new Functions();
@@ -486,6 +500,27 @@ public static Single<Map<String, Object>> returnsSingleMap() {
486500
return Single.just(ImmutableMap.of("key", "value"));
487501
}
488502

503+
public static PojoWithGettersAndSetters returnsPojo() {
504+
PojoWithGettersAndSetters pojo = new PojoWithGettersAndSetters();
505+
pojo.setField1("abc");
506+
pojo.setField2(123);
507+
return pojo;
508+
}
509+
510+
public static Single<PojoWithGettersAndSetters> returnsSinglePojo() {
511+
PojoWithGettersAndSetters pojo = new PojoWithGettersAndSetters();
512+
pojo.setField1("abc");
513+
pojo.setField2(123);
514+
return Single.just(pojo);
515+
}
516+
517+
public static Maybe<PojoWithGettersAndSetters> returnsMaybePojo() {
518+
PojoWithGettersAndSetters pojo = new PojoWithGettersAndSetters();
519+
pojo.setField1("abc");
520+
pojo.setField2(123);
521+
return Maybe.just(pojo);
522+
}
523+
489524
public void nonStaticVoidReturnWithoutSchema() {}
490525

491526
public ImmutableMap<String, Object> nonStaticReturnAllSupportedParametersAsMap(

0 commit comments

Comments
 (0)