Skip to content

Commit 646f6f7

Browse files
committed
support audio
1 parent 9a0be0f commit 646f6f7

File tree

4 files changed

+86
-32
lines changed

4 files changed

+86
-32
lines changed

src/main/java/io/github/alexcheng1982/springai/dashscope/DashscopeChatClient.java

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
import reactor.core.publisher.Flux;
4242

4343
/**
44-
* Spring AI {@linkplain ChatClient} and {@linkplain StreamingChatClient} for Aliyun Dashscope
44+
* Spring AI {@linkplain ChatClient} and {@linkplain StreamingChatClient} for
45+
* Aliyun Dashscope
4546
*/
4647
public class DashscopeChatClient extends
4748
AbstractFunctionCallSupport<ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResult> implements
@@ -97,33 +98,38 @@ public Flux<ChatResponse> stream(Prompt prompt) {
9798
var request = createRequest(prompt);
9899
if (request.isMultiModalRequest()) {
99100
return RxJava2Adapter.flowableToFlux(
100-
dashscopeApi.multiModalStream(request.getMultiModalMessages(), request.options())
101+
dashscopeApi.multiModalStream(request.getMultiModalMessages(),
102+
request.options())
101103
.map(result -> {
102104
var response = handleFunctionCallOrReturn(request,
103105
new ChatCompletionResult(result));
104106
return chatCompletionResultToChatResponse(response);
105107
}));
106108
}
107109
return RxJava2Adapter.flowableToFlux(
108-
dashscopeApi.chatCompletionStream(request.getMessages(), request.options())
110+
dashscopeApi.chatCompletionStream(request.getMessages(),
111+
request.options())
109112
.map(result -> {
110113
var response = handleFunctionCallOrReturn(request,
111114
new ChatCompletionResult(result));
112115
return chatCompletionResultToChatResponse(response);
113116
}));
114117
}
115118

116-
private ChatResponse chatCompletionResultToChatResponse(ChatCompletionResult result) {
119+
private ChatResponse chatCompletionResultToChatResponse(
120+
ChatCompletionResult result) {
117121
if (result.multiModalConversationResult() != null) {
118-
return multiModalConversationResultToChatResponse(result.multiModalConversationResult());
122+
return multiModalConversationResultToChatResponse(
123+
result.multiModalConversationResult());
119124
} else {
120125
return generationResultToChatResponse(result.generationResult());
121126
}
122127
}
123128

124129
private ChatResponse multiModalConversationResultToChatResponse(
125130
MultiModalConversationResult result) {
126-
List<org.springframework.ai.chat.Generation> generations = result.getOutput().getChoices()
131+
List<org.springframework.ai.chat.Generation> generations = result.getOutput()
132+
.getChoices()
127133
.stream()
128134
.map(choice -> new org.springframework.ai.chat.Generation(
129135
(String) choice.getMessage().getContent().get(0).get("text"))
@@ -211,7 +217,8 @@ private ToolBase toToolFunction(FunctionCallback functionCallback) {
211217

212218
@Override
213219
protected ChatCompletionRequest doCreateToolResponseRequest(
214-
ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage,
220+
ChatCompletionRequest previousRequest,
221+
ChatCompletionMessage responseMessage,
215222
List<ChatCompletionMessage> conversationHistory) {
216223
if (responseMessage.message() != null) {
217224
for (ToolCallBase toolCall : responseMessage.message().getToolCalls()) {
@@ -223,7 +230,8 @@ protected ChatCompletionRequest doCreateToolResponseRequest(
223230

224231
if (!this.functionCallbackRegister.containsKey(functionName)) {
225232
throw new IllegalStateException(
226-
"No function callback found for function name: " + functionName);
233+
"No function callback found for function name: "
234+
+ functionName);
227235
}
228236

229237
String functionResponse = this.functionCallbackRegister.get(
@@ -244,29 +252,36 @@ protected ChatCompletionRequest doCreateToolResponseRequest(
244252
}
245253

246254
@Override
247-
protected List<ChatCompletionMessage> doGetUserMessages(ChatCompletionRequest request) {
255+
protected List<ChatCompletionMessage> doGetUserMessages(
256+
ChatCompletionRequest request) {
248257
return request.messages();
249258
}
250259

251260
@Override
252-
protected ChatCompletionMessage doGetToolResponseMessage(ChatCompletionResult response) {
261+
protected ChatCompletionMessage doGetToolResponseMessage(
262+
ChatCompletionResult response) {
253263
if (response.generationResult() != null) {
254264
return new ChatCompletionMessage(
255-
response.generationResult().getOutput().getChoices().get(0).getMessage());
265+
response.generationResult().getOutput().getChoices().get(0)
266+
.getMessage());
256267
} else {
257268
return new ChatCompletionMessage(
258-
response.multiModalConversationResult().getOutput().getChoices().get(0).getMessage());
269+
response.multiModalConversationResult().getOutput().getChoices()
270+
.get(0).getMessage());
259271
}
260272
}
261273

262274
@Override
263-
protected ChatCompletionResult doChatCompletion(ChatCompletionRequest request) {
275+
protected ChatCompletionResult doChatCompletion(
276+
ChatCompletionRequest request) {
264277
if (request.isMultiModalRequest()) {
265278
return new ChatCompletionResult(
266-
this.dashscopeApi.multiModal(request.getMultiModalMessages(), request.options()));
279+
this.dashscopeApi.multiModal(request.getMultiModalMessages(),
280+
request.options()));
267281
} else {
268-
return new ChatCompletionResult(this.dashscopeApi.chatCompletion(request.getMessages(),
269-
request.options()));
282+
return new ChatCompletionResult(
283+
this.dashscopeApi.chatCompletion(request.getMessages(),
284+
request.options()));
270285
}
271286
}
272287

@@ -285,11 +300,13 @@ protected boolean isToolFunctionCall(ChatCompletionResult response) {
285300

286301
private List<ChatCompletionMessage> toDashscopeMessages(
287302
List<org.springframework.ai.chat.messages.Message> messages) {
288-
if (messages.stream().anyMatch(message -> !CollectionUtils.isEmpty(message.getMedia()))) {
303+
if (messages.stream()
304+
.anyMatch(message -> !CollectionUtils.isEmpty(message.getMedia()))) {
289305
return messages.stream().map(this::toDashscopeMultiModalMessage)
290306
.map(ChatCompletionMessage::new).toList();
291307
} else {
292-
return messages.stream().map(this::toDashscopeMessage).map(ChatCompletionMessage::new)
308+
return messages.stream().map(this::toDashscopeMessage)
309+
.map(ChatCompletionMessage::new)
293310
.toList();
294311
}
295312
}
@@ -304,9 +321,10 @@ private Message toDashscopeMessage(
304321

305322
private MultiModalMessage toDashscopeMultiModalMessage(
306323
org.springframework.ai.chat.messages.Message message) {
307-
var images = message.getMedia().stream().map(media -> new HashMap<String, Object>() {{
308-
put("image", media.getData());
309-
}}).toList();
324+
var images = message.getMedia().stream()
325+
.map(media -> new HashMap<String, Object>() {{
326+
put(media.getMimeType().getType(), media.getData());
327+
}}).toList();
310328
var content = new ArrayList<Map<String, Object>>(images);
311329
content.add(new HashMap<>() {{
312330
put("text", message.getContent());

src/main/java/io/github/alexcheng1982/springai/dashscope/api/DashscopeApi.java

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ private GenerationParam buildGenerationParam(List<Message> messages,
7676
return builder.build();
7777
}
7878

79-
public MultiModalConversationResult multiModal(List<MultiModalMessage> messages,
79+
public MultiModalConversationResult multiModal(
80+
List<MultiModalMessage> messages,
8081
DashscopeChatOptions options) {
8182
try {
8283
return multiModalConversation.call(
@@ -86,7 +87,8 @@ public MultiModalConversationResult multiModal(List<MultiModalMessage> messages,
8687
}
8788
}
8889

89-
public Flowable<MultiModalConversationResult> multiModalStream(List<MultiModalMessage> messages,
90+
public Flowable<MultiModalConversationResult> multiModalStream(
91+
List<MultiModalMessage> messages,
9092
DashscopeChatOptions options) {
9193
try {
9294
return multiModalConversation.streamCall(
@@ -97,7 +99,8 @@ public Flowable<MultiModalConversationResult> multiModalStream(List<MultiModalMe
9799
}
98100

99101
private MultiModalConversationParam buildMultiModalConversationParam(
100-
List<MultiModalMessage> messages, DashscopeChatOptions options, boolean streaming) {
102+
List<MultiModalMessage> messages, DashscopeChatOptions options,
103+
boolean streaming) {
101104
return MultiModalConversationParam.builder()
102105
.model(options.getModel())
103106
.messages(messages)
@@ -111,6 +114,12 @@ private MultiModalConversationParam buildMultiModalConversationParam(
111114
.build();
112115
}
113116

117+
/**
118+
* Union type of {@linkplain Message} and {@linkplain MultiModalMessage}
119+
*
120+
* @param message Message
121+
* @param multiModalMessage MultiModalMessage
122+
*/
114123
public record ChatCompletionMessage(
115124
Message message,
116125
MultiModalMessage multiModalMessage) {
@@ -129,7 +138,8 @@ public record ChatCompletionRequest(
129138
DashscopeChatOptions options) {
130139

131140
public boolean isMultiModalRequest() {
132-
return messages.stream().anyMatch(message -> message.multiModalMessage() != null);
141+
return messages.stream()
142+
.anyMatch(message -> message.multiModalMessage() != null);
133143
}
134144

135145
public List<Message> getMessages() {
@@ -143,6 +153,13 @@ public List<MultiModalMessage> getMultiModalMessages() {
143153
}
144154
}
145155

156+
/**
157+
* Union type of {@linkplain GenerationResult} and
158+
* {@linkplain MultiModalConversationResult}
159+
*
160+
* @param generationResult GenerationResult
161+
* @param multiModalConversationResult MultiModalConversationResult
162+
*/
146163
public record ChatCompletionResult(
147164
GenerationResult generationResult,
148165
MultiModalConversationResult multiModalConversationResult) {
@@ -151,7 +168,8 @@ public ChatCompletionResult(GenerationResult generationResult) {
151168
this(generationResult, null);
152169
}
153170

154-
public ChatCompletionResult(MultiModalConversationResult multiModalConversationResult) {
171+
public ChatCompletionResult(
172+
MultiModalConversationResult multiModalConversationResult) {
155173
this(null, multiModalConversationResult);
156174
}
157175
}

src/main/java/io/github/alexcheng1982/springai/dashscope/api/DashscopeModelName.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ public class DashscopeModelName {
1010
public static final String QWEN_14B_CHAT = "qwen-14b-chat"; // Qwen open sourced 14-billion-parameters version
1111
public static final String QWEN_VL_PLUS = "qwen-vl-plus"; // Qwen multi-modal model, supports image and text information.
1212
public static final String QWEN_VL_MAX = "qwen-vl-max"; // Qwen multi-modal model, offers optimal performance on a wider range of complex tasks.
13+
public static final String QWEN_AUDIO_TURBO = "qwen-audio-turbo";
1314

1415
// Text embedding models
1516
public static final String TEXT_EMBEDDING_V1 = "text-embedding-v1"; // Support: en, zh, es, fr, pt, id

src/test/java/io/github/alexcheng1982/springai/dashscope/DashscopeChatClientTest.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import static org.junit.jupiter.api.Assertions.assertNotNull;
44
import static org.springframework.util.MimeTypeUtils.IMAGE_JPEG;
55

6-
import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversation.Models;
6+
import io.github.alexcheng1982.springai.dashscope.api.DashscopeModelName;
77
import java.util.List;
88
import org.junit.jupiter.api.Test;
99
import org.springframework.ai.chat.messages.Media;
1010
import org.springframework.ai.chat.messages.UserMessage;
1111
import org.springframework.ai.chat.prompt.Prompt;
12+
import org.springframework.util.MimeType;
1213

1314
/**
1415
* This test requires a Dashscope API key
@@ -31,13 +32,29 @@ void streamSmokeTest() {
3132
}
3233

3334
@Test
34-
void multiModalSmokeTest() {
35+
void multiModalImageSmokeTest() {
3536
var client = DashscopeChatClient.createDefault();
36-
var prompt = new Prompt(new UserMessage("这是什么?", List.of(
37-
new Media(IMAGE_JPEG,
38-
"https://dashscope.oss-cn-beijing.aliyuncs.com/images/dog_and_girl.jpeg"))),
37+
var prompt = new Prompt(new UserMessage("这是什么?",
38+
List.of(
39+
new Media(IMAGE_JPEG,
40+
"https://dashscope.oss-cn-beijing.aliyuncs.com/images/dog_and_girl.jpeg"))),
3941
DashscopeChatOptions.builder()
40-
.withModel(Models.QWEN_VL_PLUS)
42+
.withModel(DashscopeModelName.QWEN_VL_PLUS)
43+
.build());
44+
var response = client.call(prompt);
45+
System.out.println(response.getResult().getOutput().getContent());
46+
}
47+
48+
@Test
49+
void multiModalAudioSmokeTest() {
50+
var client = DashscopeChatClient.createDefault();
51+
var prompt = new Prompt(new UserMessage("这段音频在说什么?",
52+
List.of(
53+
new Media(new MimeType("audio", "wav"),
54+
"https://dashscope.oss-cn-beijing.aliyuncs.com/audios/2channel_16K.wav"))),
55+
DashscopeChatOptions.builder()
56+
.withModel(DashscopeModelName.QWEN_AUDIO_TURBO)
57+
.withMaxTokens(100)
4158
.build());
4259
var response = client.call(prompt);
4360
System.out.println(response.getResult().getOutput().getContent());

0 commit comments

Comments
 (0)