Skip to content

Commit bf7b515

Browse files
committed
feat: Add StreamingChatClient
1 parent 21cc7a1 commit bf7b515

File tree

7 files changed

+82
-36
lines changed

7 files changed

+82
-36
lines changed

pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@
7171
<artifactId>jackson-databind</artifactId>
7272
<version>${jackson.version}</version>
7373
</dependency>
74+
<dependency>
75+
<groupId>io.projectreactor.addons</groupId>
76+
<artifactId>reactor-adapter</artifactId>
77+
<version>3.5.1</version>
78+
</dependency>
7479

7580
<dependency>
7681
<groupId>org.junit.jupiter</groupId>

src/main/java/io/github/alexcheng1982/springai/dashscope/DashscopeChatClient.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.Set;
2020
import org.springframework.ai.chat.ChatClient;
2121
import org.springframework.ai.chat.ChatResponse;
22+
import org.springframework.ai.chat.StreamingChatClient;
2223
import org.springframework.ai.chat.messages.MessageType;
2324
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
2425
import org.springframework.ai.chat.prompt.ChatOptions;
@@ -29,13 +30,16 @@
2930
import org.springframework.ai.model.function.FunctionCallbackContext;
3031
import org.springframework.util.Assert;
3132
import org.springframework.util.CollectionUtils;
33+
import reactor.adapter.rxjava.RxJava2Adapter;
34+
import reactor.core.publisher.Flux;
3235

3336
/**
34-
* Spring AI {@linkplain ChatClient} for Aliyun Dashscope
37+
* Spring AI {@linkplain ChatClient} and {@linkplain StreamingChatClient} for
38+
* Aliyun Dashscope
3539
*/
3640
public class DashscopeChatClient extends
3741
AbstractFunctionCallSupport<Message, ChatCompletionRequest, GenerationResult> implements
38-
ChatClient {
42+
ChatClient, StreamingChatClient {
3943

4044
private static final DashscopeChatOptions DEFAULT_OPTIONS = DashscopeChatOptions.builder()
4145
.withModel(DashscopeModelName.QWEN_MAX)
@@ -79,6 +83,23 @@ public static DashscopeChatClient createDefault() {
7983
@Override
8084
public ChatResponse call(Prompt prompt) {
8185
var generationResult = callWithFunctionSupport(createRequest(prompt));
86+
return generationResultToChatResponse(generationResult);
87+
}
88+
89+
@Override
90+
public Flux<ChatResponse> stream(Prompt prompt) {
91+
var request = createRequest(prompt);
92+
return RxJava2Adapter.flowableToFlux(
93+
dashscopeApi.chatCompletionStream(request.messages(), request.options())
94+
.map(result -> {
95+
var generationResult = handleFunctionCallOrReturn(request,
96+
result);
97+
return generationResultToChatResponse(generationResult);
98+
}));
99+
}
100+
101+
private ChatResponse generationResultToChatResponse(
102+
GenerationResult generationResult) {
82103
List<org.springframework.ai.chat.Generation> generations = generationResult.getOutput()
83104
.getChoices().stream()
84105
.map(choice -> new org.springframework.ai.chat.Generation(

src/main/java/io/github/alexcheng1982/springai/dashscope/DashscopeEmbeddingClient.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import com.alibaba.dashscope.embeddings.TextEmbedding;
44
import com.alibaba.dashscope.embeddings.TextEmbeddingParam;
55
import com.alibaba.dashscope.embeddings.TextEmbeddingParam.TextType;
6+
import com.alibaba.dashscope.exception.ApiException;
67
import com.alibaba.dashscope.exception.NoApiKeyException;
8+
import io.github.alexcheng1982.springai.dashscope.api.DashscopeApiException;
79
import io.github.alexcheng1982.springai.dashscope.api.DashscopeModelName;
810
import java.util.List;
911
import org.springframework.ai.document.Document;
@@ -12,6 +14,9 @@
1214
import org.springframework.ai.embedding.EmbeddingRequest;
1315
import org.springframework.ai.embedding.EmbeddingResponse;
1416

17+
/**
18+
* Spring AI {@linkplain EmbeddingClient} for Aliyun Dashscope
19+
*/
1520
public class DashscopeEmbeddingClient implements EmbeddingClient {
1621

1722
@Override
@@ -32,11 +37,12 @@ public EmbeddingResponse call(EmbeddingRequest request) {
3237
TextEmbedding embedding = new TextEmbedding();
3338
try {
3439
var result = embedding.call(builder.build());
35-
return new EmbeddingResponse(result.getOutput().getEmbeddings().stream().map(item ->
36-
new Embedding(item.getEmbedding(), item.getTextIndex())
37-
).toList());
38-
} catch (NoApiKeyException e) {
39-
throw new RuntimeException(e);
40+
return new EmbeddingResponse(
41+
result.getOutput().getEmbeddings().stream().map(item ->
42+
new Embedding(item.getEmbedding(), item.getTextIndex())
43+
).toList());
44+
} catch (ApiException | NoApiKeyException e) {
45+
throw new DashscopeApiException(e);
4046
}
4147
}
4248

src/main/java/io/github/alexcheng1982/springai/dashscope/api/DashscopeApi.java

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
import com.alibaba.dashscope.aigc.generation.GenerationParam;
77
import com.alibaba.dashscope.aigc.generation.GenerationResult;
88
import com.alibaba.dashscope.common.Message;
9+
import com.alibaba.dashscope.exception.ApiException;
910
import com.alibaba.dashscope.exception.InputRequiredException;
1011
import com.alibaba.dashscope.exception.NoApiKeyException;
1112
import io.github.alexcheng1982.springai.dashscope.DashscopeChatOptions;
13+
import io.reactivex.Flowable;
1214
import java.util.List;
1315
import java.util.Objects;
1416
import java.util.Optional;
@@ -27,6 +29,25 @@ public DashscopeApi() {
2729
public GenerationResult chatCompletion(
2830
List<Message> messages,
2931
DashscopeChatOptions options) {
32+
try {
33+
return generation.call(buildGenerationParam(messages, options, false));
34+
} catch (ApiException | NoApiKeyException | InputRequiredException e) {
35+
throw new DashscopeApiException(e);
36+
}
37+
}
38+
39+
public Flowable<GenerationResult> chatCompletionStream(List<Message> messages,
40+
DashscopeChatOptions options) {
41+
try {
42+
return generation.streamCall(
43+
buildGenerationParam(messages, options, true));
44+
} catch (ApiException | NoApiKeyException | InputRequiredException e) {
45+
throw new DashscopeApiException(e);
46+
}
47+
}
48+
49+
private GenerationParam buildGenerationParam(List<Message> messages,
50+
DashscopeChatOptions options, boolean streaming) {
3051
var builder = GenerationParam.builder()
3152
.model(options.getModel())
3253
.topP(Optional.ofNullable(options.getTopP()).map(Double::valueOf)
@@ -39,20 +60,15 @@ public GenerationResult chatCompletion(
3960
.maxTokens(options.getMaxTokens())
4061
.messages(messages)
4162
.tools(options.getTools())
42-
.resultFormat(MESSAGE);
63+
.resultFormat(MESSAGE)
64+
.incrementalOutput(streaming);
4365

4466
if (options.getStops() != null) {
4567
builder.stopStrings(options.getStops());
4668
}
47-
48-
try {
49-
return generation.call(builder.build());
50-
} catch (NoApiKeyException | InputRequiredException e) {
51-
throw new RuntimeException(e);
52-
}
69+
return builder.build();
5370
}
5471

55-
5672
public record ChatCompletionRequest(
5773
List<Message> messages,
5874
DashscopeChatOptions options) {
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package io.github.alexcheng1982.springai.dashscope.api;
2+
3+
public class DashscopeApiException extends RuntimeException {
4+
5+
public DashscopeApiException(Throwable cause) {
6+
super(cause);
7+
}
8+
}
Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,6 @@
11
{
2-
"resources":{
3-
"includes":[{
4-
"pattern":"\\QMETA-INF/services/java.lang.System$LoggerFinder\\E"
5-
}, {
6-
"pattern":"\\QMETA-INF/services/java.net.spi.InetAddressResolverProvider\\E"
7-
}, {
8-
"pattern":"\\QMETA-INF/services/java.net.spi.URLStreamHandlerProvider\\E"
9-
}, {
10-
"pattern":"\\QMETA-INF/services/java.time.zone.ZoneRulesProvider\\E"
11-
}, {
12-
"pattern":"\\QMETA-INF/services/org.slf4j.spi.SLF4JServiceProvider\\E"
13-
}, {
14-
"pattern":"\\Qsimplelogger.properties\\E"
15-
}, {
16-
"pattern":"java.base:\\Qjdk/internal/icu/impl/data/icudt72b/nfkc.nrm\\E"
17-
}, {
18-
"pattern":"java.base:\\Qjdk/internal/icu/impl/data/icudt72b/uprops.icu\\E"
19-
}, {
20-
"pattern":"java.base:\\Qsun/net/idn/uidna.spp\\E"
21-
}]},
22-
"bundles":[]
2+
"resources": {
3+
"includes": []
4+
},
5+
"bundles": []
236
}

src/test/java/io/github/alexcheng1982/springai/dashscope/DashscopeChatClientTest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,11 @@ void smokeTest() {
1616
var response = client.call("hello");
1717
assertNotNull(response);
1818
}
19+
20+
@Test
21+
void streamSmokeTest() {
22+
var client = DashscopeChatClient.createDefault();
23+
var response = client.stream("如何做西红柿炖牛腩?");
24+
response.toIterable().forEach(System.out::println);
25+
}
1926
}

0 commit comments

Comments
 (0)