Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 24 additions & 54 deletions core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,26 +70,15 @@ public static FunctionDeclaration buildFunctionDeclaration(

Type returnType = func.getGenericReturnType();
if (returnType != Void.TYPE) {
Type realReturnType = returnType;
if (returnType instanceof ParameterizedType) {
ParameterizedType parameterizedReturnType = (ParameterizedType) returnType;
String returnTypeName = ((Class<?>) parameterizedReturnType.getRawType()).getName();
if (returnTypeName.equals("io.reactivex.rxjava3.core.Maybe")
|| returnTypeName.equals("io.reactivex.rxjava3.core.Single")) {
returnType = parameterizedReturnType.getActualTypeArguments()[0];
if (returnType instanceof ParameterizedType) {
ParameterizedType maybeParameterizedType = (ParameterizedType) returnType;
returnTypeName = ((Class<?>) maybeParameterizedType.getRawType()).getName();
}
}
if (returnTypeName.equals("java.util.Map")
|| returnTypeName.equals("com.google.common.collect.ImmutableMap")) {
return builder.response(buildSchemaFromType(returnType)).build();
Type actualReturnType = returnType;
if (returnType instanceof ParameterizedType parameterizedReturnType) {
String rawTypeName = ((Class<?>) parameterizedReturnType.getRawType()).getName();
if (rawTypeName.equals("io.reactivex.rxjava3.core.Maybe")
|| rawTypeName.equals("io.reactivex.rxjava3.core.Single")) {
actualReturnType = parameterizedReturnType.getActualTypeArguments()[0];
}
}
throw new IllegalArgumentException(
"Return type should be Map or Maybe<Map> or Single<Map>, but it was "
+ realReturnType.getTypeName());
builder.response(buildSchemaFromType(actualReturnType));
}
return builder.build();
}
Expand All @@ -107,50 +96,31 @@ static FunctionDeclaration buildFunctionDeclaration(JsonBaseModel func, String d
}

private static Schema buildSchemaFromParameter(Parameter param) {
Schema.Builder builder = Schema.builder();
Schema schema = buildSchemaFromType(param.getParameterizedType());
if (param.isAnnotationPresent(Annotations.Schema.class)
&& !param.getAnnotation(Annotations.Schema.class).description().isEmpty()) {
builder.description(param.getAnnotation(Annotations.Schema.class).description());
}
switch (param.getType().getName()) {
case "java.lang.String" -> builder.type("STRING");
case "boolean", "java.lang.Boolean" -> builder.type("BOOLEAN");
case "int", "java.lang.Integer" -> builder.type("INTEGER");
case "double", "java.lang.Double", "float", "java.lang.Float", "long", "java.lang.Long" ->
builder.type("NUMBER");
case "java.util.List" ->
builder
.type("ARRAY")
.items(
buildSchemaFromType(
((ParameterizedType) param.getParameterizedType())
.getActualTypeArguments()[0]));
case "java.util.Map" -> builder.type("OBJECT");
default -> {
BeanDescription beanDescription =
OBJECT_MAPPER
.getSerializationConfig()
.introspect(OBJECT_MAPPER.constructType(param.getType()));
Map<String, Schema> properties = new LinkedHashMap<>();
for (BeanPropertyDefinition property : beanDescription.findProperties()) {
properties.put(property.getName(), buildSchemaFromType(property.getRawPrimaryType()));
}
builder.type("OBJECT").properties(properties);
}
return schema.toBuilder()
.description(param.getAnnotation(Annotations.Schema.class).description())
.build();
}
return builder.build();
return schema;
}

public static Schema buildSchemaFromType(Type type) {
Schema.Builder builder = Schema.builder();
if (type instanceof ParameterizedType parameterizedType) {
switch (((Class<?>) parameterizedType.getRawType()).getName()) {
case "java.util.List" ->
builder
.type("ARRAY")
.items(buildSchemaFromType(parameterizedType.getActualTypeArguments()[0]));
case "java.util.Map", "com.google.common.collect.ImmutableMap" -> builder.type("OBJECT");
default -> throw new IllegalArgumentException("Unsupported generic type: " + type);
String rawTypeName = ((Class<?>) parameterizedType.getRawType()).getName();
switch (rawTypeName) {
case "java.util.List", "com.google.common.collect.ImmutableList":
Schema itemSchema = buildSchemaFromType(parameterizedType.getActualTypeArguments()[0]);
builder.type("ARRAY").items(itemSchema);
break;
case "java.util.Map":
case "com.google.common.collect.ImmutableMap":
builder.type("OBJECT");
break;
default:
throw new IllegalArgumentException("Unsupported generic type: " + type);
}
} else if (type instanceof Class<?> clazz) {
switch (clazz.getName()) {
Expand Down
14 changes: 11 additions & 3 deletions core/src/main/java/com/google/adk/tools/FunctionTool.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.adk.tools;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -196,11 +197,18 @@ private Maybe<Map<String, Object>> call(Map<String, Object> args, ToolContext to
if (result == null) {
return Maybe.empty();
} else if (result instanceof Maybe) {
return (Maybe<Map<String, Object>>) result;
return ((Maybe<?>) result)
.map(
data ->
OBJECT_MAPPER.convertValue(data, new TypeReference<Map<String, Object>>() {}));
} else if (result instanceof Single) {
return ((Single<Map<String, Object>>) result).toMaybe();
return ((Single<?>) result)
.map(
data -> OBJECT_MAPPER.convertValue(data, new TypeReference<Map<String, Object>>() {}))
.toMaybe();
} else {
return Maybe.just((Map<String, Object>) result);
return Maybe.just(
OBJECT_MAPPER.convertValue(result, new TypeReference<Map<String, Object>>() {}));
}
}

Expand Down
49 changes: 42 additions & 7 deletions core/src/test/java/com/google/adk/tools/FunctionToolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -337,13 +337,6 @@ public void call_withMaybeMapReturnType() throws Exception {
assertThat(result).containsExactly("key", "value");
}

@Test
public void create_withMaybeStringReturnType() {
assertThrows(
IllegalArgumentException.class,
() -> FunctionTool.create(Functions.class, "returnsMaybeString"));
}

@Test
public void create_withSingleMapReturnType() {
FunctionTool tool = FunctionTool.create(Functions.class, "returnsSingleMap");
Expand All @@ -362,6 +355,27 @@ public void call_withSingleMapReturnType() throws Exception {
assertThat(result).containsExactly("key", "value");
}

@Test
public void call_withPojoReturnType() throws Exception {
FunctionTool tool = FunctionTool.create(Functions.class, "returnsPojo");
Map<String, Object> result = tool.runAsync(ImmutableMap.of(), null).blockingGet();
assertThat(result).containsExactly("field1", "abc", "field2", 123);
}

@Test
public void call_withSinglePojoReturnType() throws Exception {
FunctionTool tool = FunctionTool.create(Functions.class, "returnsSinglePojo");
Map<String, Object> result = tool.runAsync(ImmutableMap.of(), null).blockingGet();
assertThat(result).containsExactly("field1", "abc", "field2", 123);
}

@Test
public void call_withMaybePojoReturnType() throws Exception {
FunctionTool tool = FunctionTool.create(Functions.class, "returnsMaybePojo");
Map<String, Object> result = tool.runAsync(ImmutableMap.of(), null).blockingGet();
assertThat(result).containsExactly("field1", "abc", "field2", 123);
}

@Test
public void call_nonStaticWithAllSupportedParameterTypes() throws Exception {
Functions functions = new Functions();
Expand Down Expand Up @@ -486,6 +500,27 @@ public static Single<Map<String, Object>> returnsSingleMap() {
return Single.just(ImmutableMap.of("key", "value"));
}

public static PojoWithGettersAndSetters returnsPojo() {
PojoWithGettersAndSetters pojo = new PojoWithGettersAndSetters();
pojo.setField1("abc");
pojo.setField2(123);
return pojo;
}

public static Single<PojoWithGettersAndSetters> returnsSinglePojo() {
PojoWithGettersAndSetters pojo = new PojoWithGettersAndSetters();
pojo.setField1("abc");
pojo.setField2(123);
return Single.just(pojo);
}

public static Maybe<PojoWithGettersAndSetters> returnsMaybePojo() {
PojoWithGettersAndSetters pojo = new PojoWithGettersAndSetters();
pojo.setField1("abc");
pojo.setField2(123);
return Maybe.just(pojo);
}

public void nonStaticVoidReturnWithoutSchema() {}

public ImmutableMap<String, Object> nonStaticReturnAllSupportedParametersAsMap(
Expand Down