Skip to content

Commit f550839

Browse files
tzolovnamsoo2
authored andcommitted
feat: Add support for generic argument types in tool callbacks
- Enhance MethodToolCallback to properly handle generic types by using parameterized types - Add unit tests for generic type handling (List, Map<String,Integer>, nested generics) - Add integration tests for both Anthropic and OpenAI clients to verify tool calls with generic argument types Resolves spring-projects#2462 Signed-off-by: Christian Tzolov <[email protected]> Signed-off-by: minsoo.nam <[email protected]>
1 parent 7e29bff commit f550839

File tree

5 files changed

+375
-5
lines changed

5 files changed

+375
-5
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.anthropic.client;
18+
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.concurrent.ConcurrentHashMap;
22+
import java.util.concurrent.atomic.AtomicLong;
23+
24+
import org.junit.jupiter.api.BeforeEach;
25+
import org.junit.jupiter.api.Test;
26+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
27+
import org.slf4j.Logger;
28+
import org.slf4j.LoggerFactory;
29+
30+
import org.springframework.ai.anthropic.AnthropicTestConfiguration;
31+
import org.springframework.ai.chat.client.ChatClient;
32+
import org.springframework.ai.chat.model.ChatModel;
33+
import org.springframework.ai.tool.annotation.Tool;
34+
import org.springframework.ai.tool.annotation.ToolParam;
35+
import org.springframework.beans.factory.annotation.Autowired;
36+
import org.springframework.boot.test.context.SpringBootTest;
37+
38+
import static org.assertj.core.api.Assertions.assertThat;
39+
40+
@SpringBootTest(classes = AnthropicTestConfiguration.class)
41+
@EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+")
42+
class ChatClientToolsWithGenericArgumentTypesIT {
43+
44+
private static final Logger logger = LoggerFactory.getLogger(ChatClientToolsWithGenericArgumentTypesIT.class);
45+
46+
public static Map<String, Object> arguments = new ConcurrentHashMap<>();
47+
48+
public static AtomicLong callCounter = new AtomicLong(0);
49+
50+
@BeforeEach
51+
void beforeEach() {
52+
arguments.clear();
53+
}
54+
55+
@Autowired
56+
ChatModel chatModel;
57+
58+
@Test
59+
void toolWithGenericArgumentTypes() {
60+
// @formatter:off
61+
String response = ChatClient.create(this.chatModel).prompt()
62+
.user("Turn light red in the living room and the kitchen. Please group the romms with the same color in a single tool call.")
63+
.tools(new TestToolProvider())
64+
.call()
65+
.content();
66+
// @formatter:on
67+
68+
logger.info("Response: {}", response);
69+
70+
assertThat(arguments).containsEntry("living room", LightColor.RED);
71+
assertThat(arguments).containsEntry("kitchen", LightColor.RED);
72+
73+
assertThat(callCounter.get()).isEqualTo(1);
74+
}
75+
76+
record Room(String name) {
77+
}
78+
79+
enum LightColor {
80+
81+
RED, GREEN, BLUE
82+
83+
}
84+
85+
public static class TestToolProvider {
86+
87+
@Tool(description = "Change the lamp color in a room.")
88+
public void changeRoomLightColor(
89+
@ToolParam(description = "List of rooms to change the ligth color for") List<Room> rooms,
90+
@ToolParam(description = "light color to change to") LightColor color) {
91+
92+
logger.info("Change light color in rooms: {} to color: {}", rooms, color);
93+
94+
for (Room room : rooms) {
95+
arguments.put(room.name(), color);
96+
}
97+
callCounter.incrementAndGet();
98+
}
99+
100+
}
101+
102+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat.client;
18+
19+
import java.util.List;
20+
import java.util.Map;
21+
import java.util.concurrent.ConcurrentHashMap;
22+
import java.util.concurrent.atomic.AtomicLong;
23+
24+
import org.junit.jupiter.api.BeforeEach;
25+
import org.junit.jupiter.api.Test;
26+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
27+
import org.slf4j.Logger;
28+
import org.slf4j.LoggerFactory;
29+
30+
import org.springframework.ai.chat.client.ChatClient;
31+
import org.springframework.ai.chat.model.ChatModel;
32+
import org.springframework.ai.openai.OpenAiTestConfiguration;
33+
import org.springframework.ai.tool.annotation.Tool;
34+
import org.springframework.ai.tool.annotation.ToolParam;
35+
import org.springframework.beans.factory.annotation.Autowired;
36+
import org.springframework.boot.test.context.SpringBootTest;
37+
import org.springframework.test.context.ActiveProfiles;
38+
39+
import static org.assertj.core.api.Assertions.assertThat;
40+
41+
@SpringBootTest(classes = OpenAiTestConfiguration.class)
42+
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
43+
@ActiveProfiles("logging-test")
44+
class ChatClientToolsWithGenericArgumentTypesIT {
45+
46+
private static final Logger logger = LoggerFactory.getLogger(ChatClientToolsWithGenericArgumentTypesIT.class);
47+
48+
public static Map<String, Object> arguments = new ConcurrentHashMap<>();
49+
50+
public static AtomicLong callCounter = new AtomicLong(0);
51+
52+
@BeforeEach
53+
void beforeEach() {
54+
arguments.clear();
55+
}
56+
57+
@Autowired
58+
ChatModel chatModel;
59+
60+
@Test
61+
void toolWithGenericArgumentTypes() {
62+
// @formatter:off
63+
String response = ChatClient.create(this.chatModel).prompt()
64+
.user("Turn light red in the living room and the kitchen. Please group the romms with the same color in a single tool call.")
65+
.tools(new TestToolProvider())
66+
.call()
67+
.content();
68+
// @formatter:on
69+
70+
logger.info("Response: {}", response);
71+
72+
assertThat(arguments).containsEntry("living room", LightColor.RED);
73+
assertThat(arguments).containsEntry("kitchen", LightColor.RED);
74+
75+
assertThat(callCounter.get()).isEqualTo(1);
76+
}
77+
78+
record Room(String name) {
79+
}
80+
81+
enum LightColor {
82+
83+
RED, GREEN, BLUE
84+
85+
}
86+
87+
public static class TestToolProvider {
88+
89+
@Tool(description = "Change the lamp color in a room.")
90+
public void changeRoomLightColor(
91+
@ToolParam(description = "List of rooms to change the ligth color for") List<Room> rooms,
92+
@ToolParam(description = "light color to change to") LightColor color) {
93+
94+
logger.info("Change light color in rooms: {} to color: {}", rooms, color);
95+
96+
for (Room room : rooms) {
97+
arguments.put(room.name(), color);
98+
}
99+
callCounter.incrementAndGet();
100+
}
101+
102+
}
103+
104+
}

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMemoryAdvisorReproIT.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import org.springframework.boot.test.context.SpringBootTest;
1919
import org.springframework.test.context.ActiveProfiles;
2020

21-
import static org.assertj.core.api.Assertions.assertThatThrownBy;
22-
2321
@SpringBootTest(classes = OpenAiTestConfiguration.class)
2422
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
2523
@ActiveProfiles("logging-test")

spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,16 +136,23 @@ private Object[] buildMethodArguments(Map<String, Object> toolInputArguments, @N
136136
return toolContext;
137137
}
138138
Object rawArgument = toolInputArguments.get(parameter.getName());
139-
return buildTypedArgument(rawArgument, parameter.getType());
139+
return buildTypedArgument(rawArgument, parameter.getParameterizedType());
140140
}).toArray();
141141
}
142142

143143
@Nullable
144-
private Object buildTypedArgument(@Nullable Object value, Class<?> type) {
144+
private Object buildTypedArgument(@Nullable Object value, Type type) {
145145
if (value == null) {
146146
return null;
147147
}
148-
return JsonParser.toTypedObject(value, type);
148+
149+
if (type instanceof Class<?>) {
150+
return JsonParser.toTypedObject(value, (Class<?>) type);
151+
}
152+
153+
// For generic types, use the fromJson method that accepts Type
154+
String json = JsonParser.toJson(value);
155+
return JsonParser.fromJson(json, type);
149156
}
150157

151158
@Nullable

0 commit comments

Comments
 (0)