Skip to content

Commit a22425a

Browse files
authored
DashScope: support partial tool call and reasoning (#308)
1 parent 2e80983 commit a22425a

File tree

11 files changed

+467
-20
lines changed

11 files changed

+467
-20
lines changed

models/langchain4j-community-dashscope/src/main/java/dev/langchain4j/community/model/dashscope/QwenChatResponseMetadata.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import static dev.langchain4j.internal.Utils.quoted;
44

55
import dev.langchain4j.Experimental;
6+
import dev.langchain4j.data.message.AiMessage;
67
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
78
import java.util.HashMap;
89
import java.util.List;
@@ -12,6 +13,10 @@
1213
@Experimental
1314
public class QwenChatResponseMetadata extends ChatResponseMetadata {
1415
private final SearchInfo searchInfo;
16+
/**
17+
* @deprecated Please use {@link AiMessage#thinking} instead.
18+
*/
19+
@Deprecated(since = "1.2.0", forRemoval = true)
1520
private final String reasoningContent;
1621

1722
protected QwenChatResponseMetadata(Builder builder) {
@@ -24,6 +29,10 @@ public SearchInfo searchInfo() {
2429
return searchInfo;
2530
}
2631

32+
/**
33+
* @deprecated Please use {@link AiMessage#thinking()} instead.
34+
*/
35+
@Deprecated(since = "1.2.0", forRemoval = true)
2736
public String reasoningContent() {
2837
return reasoningContent;
2938
}
@@ -93,6 +102,10 @@ public Builder searchInfo(SearchInfo searchInfo) {
93102
return this;
94103
}
95104

105+
/**
106+
* @deprecated Please use {@link AiMessage#thinking()} instead.
107+
*/
108+
@Deprecated(since = "1.2.0", forRemoval = true)
96109
public Builder reasoningContent(String reasoningContent) {
97110
this.reasoningContent = reasoningContent;
98111
return this;
@@ -133,12 +146,12 @@ public SearchInfo build() {
133146
* Results from online searches.
134147
*
135148
* @param siteName the name of the website from which the search results came
136-
* @param icon the URL of the icon from the source website, or an empty string if there is
137-
* no icon
138-
* @param index the sequence number of the search result, indicating the index of the
139-
* search result in search_results
140-
* @param title the title of the search result
141-
* @param url the URL of the search result
149+
* @param icon the URL of the icon from the source website, or an empty string if there is
150+
* no icon
151+
* @param index the sequence number of the search result, indicating the index of the
152+
* search result in search_results
153+
* @param title the title of the search result
154+
* @param url the URL of the search result
142155
*/
143156
public record SearchResult(String siteName, String icon, Integer index, String title, String url) {
144157
public static Builder builder() {

models/langchain4j-community-dashscope/src/main/java/dev/langchain4j/community/model/dashscope/QwenHelper.java

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import dev.langchain4j.model.output.FinishReason;
6262
import dev.langchain4j.model.output.Response;
6363
import dev.langchain4j.model.output.TokenUsage;
64+
import java.util.ArrayList;
6465
import java.util.Collection;
6566
import java.util.Collections;
6667
import java.util.HashMap;
@@ -236,6 +237,18 @@ static boolean hasAnswer(GenerationResult result) {
236237
.isPresent();
237238
}
238239

240+
static boolean hasReasoningContent(GenerationResult result) {
241+
return Optional.of(result)
242+
.map(GenerationResult::getOutput)
243+
.map(GenerationOutput::getChoices)
244+
.filter(choices -> !choices.isEmpty())
245+
.map(choices -> choices.get(0))
246+
.map(Choice::getMessage)
247+
.map(Message::getReasoningContent)
248+
.filter(Utils::isNotNullOrEmpty)
249+
.isPresent();
250+
}
251+
239252
static String answerFrom(GenerationResult result) {
240253
return Optional.of(result)
241254
.map(GenerationResult::getOutput)
@@ -423,14 +436,16 @@ static ChatResponse chatResponseFrom(String modelName, GenerationResult result)
423436
}
424437

425438
static AiMessage aiMessageFrom(GenerationResult result) {
439+
String text = answerFrom(result);
440+
String reasoningContentFrom = reasoningContentFrom(result);
441+
AiMessage.Builder aiMessageBuilder = AiMessage.builder()
442+
.text(isNullOrBlank(text) ? null : text)
443+
.thinking(isNullOrBlank(reasoningContentFrom) ? null : reasoningContentFrom);
426444
if (isFunctionToolCalls(result)) {
427-
String text = answerFrom(result);
428-
return isNullOrBlank(text)
429-
? new AiMessage(toolExecutionRequestsFrom(result))
430-
: new AiMessage(text, toolExecutionRequestsFrom(result));
431-
} else {
432-
return new AiMessage(answerFrom(result));
445+
aiMessageBuilder = aiMessageBuilder.toolExecutionRequests(toolExecutionRequestsFrom(result));
433446
}
447+
448+
return aiMessageBuilder.build();
434449
}
435450

436451
private static List<ToolExecutionRequest> toolExecutionRequestsFrom(GenerationResult result) {
@@ -445,6 +460,22 @@ private static List<ToolExecutionRequest> toolExecutionRequestsFrom(GenerationRe
445460
.collect(toList());
446461
}
447462

463+
static List<ToolCallFunction> toolCallFunctionsFrom(GenerationResult result) {
464+
List<ToolCallBase> toolCalls = Optional.of(result)
465+
.map(GenerationResult::getOutput)
466+
.map(GenerationOutput::getChoices)
467+
.filter(choices -> !choices.isEmpty())
468+
.map(choices -> choices.get(0))
469+
.map(Choice::getMessage)
470+
.map(Message::getToolCalls)
471+
.orElse(new ArrayList<>());
472+
473+
return toolCalls.stream()
474+
.filter(ToolCallFunction.class::isInstance)
475+
.map(ToolCallFunction.class::cast)
476+
.collect(toList());
477+
}
478+
448479
static List<ToolCallBase> toolCallsFrom(GenerationResult result) {
449480
return Optional.of(result)
450481
.map(GenerationResult::getOutput)
@@ -488,6 +519,7 @@ static ChatResponse chatResponseFrom(String modelName, MultiModalConversationRes
488519
.modelName(modelName)
489520
.tokenUsage(tokenUsageFrom(result))
490521
.finishReason(finishReasonFrom(result))
522+
// deprecated
491523
.reasoningContent(reasoningContentFrom(result))
492524
.build())
493525
.build();
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package dev.langchain4j.community.model.dashscope;
2+
3+
import dev.langchain4j.model.chat.response.CompleteToolCall;
4+
import dev.langchain4j.model.chat.response.PartialThinking;
5+
import dev.langchain4j.model.chat.response.PartialToolCall;
6+
import java.util.List;
7+
import java.util.Objects;
8+
9+
public class QwenPartialResponse {
10+
11+
private final String delta;
12+
private final PartialThinking partialThinking;
13+
private final List<PartialToolCall> partialToolCalls;
14+
private final List<CompleteToolCall> completeToolCalls;
15+
16+
private QwenPartialResponse(Builder builder) {
17+
this.delta = builder.delta;
18+
this.partialThinking = builder.partialThinking;
19+
this.partialToolCalls = builder.partialToolCalls;
20+
this.completeToolCalls = builder.completeToolCalls;
21+
}
22+
23+
public String delta() {
24+
return delta;
25+
}
26+
27+
public PartialThinking partialThinking() {
28+
return partialThinking;
29+
}
30+
31+
public List<PartialToolCall> partialToolCalls() {
32+
return partialToolCalls;
33+
}
34+
35+
public List<CompleteToolCall> completeToolCalls() {
36+
return completeToolCalls;
37+
}
38+
39+
@Override
40+
public boolean equals(final Object o) {
41+
if (o == null || getClass() != o.getClass()) return false;
42+
QwenPartialResponse that = (QwenPartialResponse) o;
43+
return Objects.equals(delta, that.delta)
44+
&& Objects.equals(partialThinking, that.partialThinking)
45+
&& Objects.equals(partialToolCalls, that.partialToolCalls)
46+
&& Objects.equals(completeToolCalls, that.completeToolCalls);
47+
}
48+
49+
@Override
50+
public int hashCode() {
51+
return Objects.hash(delta, partialThinking, partialToolCalls, completeToolCalls);
52+
}
53+
54+
@Override
55+
public String toString() {
56+
return "QwenPartialResponse{" + "delta='"
57+
+ delta + '\'' + ", partialThinking="
58+
+ partialThinking + ", partialToolCalls="
59+
+ partialToolCalls + ", completeToolCalls="
60+
+ completeToolCalls + '}';
61+
}
62+
63+
static Builder builder() {
64+
return new Builder();
65+
}
66+
67+
public static class Builder {
68+
69+
private String delta;
70+
private PartialThinking partialThinking;
71+
private List<PartialToolCall> partialToolCalls;
72+
private List<CompleteToolCall> completeToolCalls;
73+
74+
public Builder delta(String delta) {
75+
this.delta = delta;
76+
return this;
77+
}
78+
79+
public Builder partialThinking(PartialThinking partialThinking) {
80+
this.partialThinking = partialThinking;
81+
return this;
82+
}
83+
84+
public Builder partialToolCalls(List<PartialToolCall> partialToolCalls) {
85+
this.partialToolCalls = partialToolCalls;
86+
return this;
87+
}
88+
89+
public Builder completeToolCalls(List<CompleteToolCall> completeToolCalls) {
90+
this.completeToolCalls = completeToolCalls;
91+
return this;
92+
}
93+
94+
public QwenPartialResponse build() {
95+
return new QwenPartialResponse(this);
96+
}
97+
}
98+
}

models/langchain4j-community-dashscope/src/main/java/dev/langchain4j/community/model/dashscope/QwenStreamingChatModel.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static dev.langchain4j.internal.Utils.getOrDefault;
1111
import static dev.langchain4j.internal.Utils.isNotNullOrEmpty;
1212
import static dev.langchain4j.internal.Utils.isNullOrBlank;
13+
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
1314
import static dev.langchain4j.internal.Utils.quoted;
1415
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
1516
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
@@ -36,6 +37,8 @@
3637
import dev.langchain4j.model.chat.request.ChatRequest;
3738
import dev.langchain4j.model.chat.request.ChatRequestParameters;
3839
import dev.langchain4j.model.chat.request.DefaultChatRequestParameters;
40+
import dev.langchain4j.model.chat.response.CompleteToolCall;
41+
import dev.langchain4j.model.chat.response.PartialToolCall;
3942
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
4043
import java.util.ArrayList;
4144
import java.util.List;
@@ -146,9 +149,24 @@ private void generateByNonMultimodalModel(ChatRequest chatRequest, StreamingChat
146149
@Override
147150
public void onEvent(GenerationResult result) {
148151
try {
149-
String delta = responseBuilder.append(result);
150-
if (isNotNullOrEmpty(delta)) {
151-
handler.onPartialResponse(delta);
152+
QwenPartialResponse partialResponse = responseBuilder.append(result);
153+
if (isNotNullOrEmpty(partialResponse.delta())) {
154+
handler.onPartialResponse(partialResponse.delta());
155+
}
156+
if (partialResponse.partialThinking() != null) {
157+
handler.onPartialThinking(partialResponse.partialThinking());
158+
}
159+
List<PartialToolCall> partialToolCalls = partialResponse.partialToolCalls();
160+
if (!isNullOrEmpty(partialToolCalls)) {
161+
for (PartialToolCall toolCall : partialToolCalls) {
162+
handler.onPartialToolCall(toolCall);
163+
}
164+
}
165+
List<CompleteToolCall> completeToolCalls = partialResponse.completeToolCalls();
166+
if (!isNullOrEmpty(completeToolCalls)) {
167+
for (CompleteToolCall toolCall : completeToolCalls) {
168+
handler.onCompleteToolCall(toolCall);
169+
}
152170
}
153171
} catch (Throwable t) {
154172
RuntimeException mappedException = ExceptionMapper.DEFAULT.mapException(t);
@@ -160,6 +178,10 @@ public void onEvent(GenerationResult result) {
160178
public void onComplete() {
161179
try {
162180
handler.onCompleteResponse(responseBuilder.build());
181+
CompleteToolCall completeToolCall = responseBuilder.buildCompleteToolCall();
182+
if (completeToolCall != null) {
183+
handler.onCompleteToolCall(completeToolCall);
184+
}
163185
} catch (Throwable t) {
164186
RuntimeException mappedException = ExceptionMapper.DEFAULT.mapException(t);
165187
withLoggingExceptions(() -> handler.onError(mappedException));

models/langchain4j-community-dashscope/src/main/java/dev/langchain4j/community/model/dashscope/QwenStreamingLanguageModel.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ public void generate(String prompt, StreamingResponseHandler<String> handler) {
109109
generation.streamCall(builder.build(), new ResultCallback<>() {
110110
@Override
111111
public void onEvent(GenerationResult result) {
112-
String delta = responseBuilder.append(result);
113-
if (Utils.isNotNullOrBlank(delta)) {
114-
handler.onNext(delta);
112+
QwenPartialResponse partialResponse = responseBuilder.append(result);
113+
if (Utils.isNotNullOrBlank(partialResponse.delta())) {
114+
handler.onNext(partialResponse.delta());
115115
}
116116
}
117117

0 commit comments

Comments
 (0)