Skip to content

Commit c172449

Browse files
songguocolakevinlin09
authored andcommitted
add: add new websocket client for audio models with thread pool fix
1 parent f0a2139 commit c172449

File tree

5 files changed

+222
-8
lines changed

5 files changed

+222
-8
lines changed

src/main/java/com/alibaba/dashscope/multimodal/MultiModalDialog.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import com.alibaba.dashscope.exception.InputRequiredException;
88
import com.alibaba.dashscope.exception.NoApiKeyException;
99
import com.alibaba.dashscope.protocol.ApiServiceOption;
10+
import com.alibaba.dashscope.protocol.ConnectionOptions;
1011
import com.alibaba.dashscope.protocol.Protocol;
1112
import com.alibaba.dashscope.protocol.StreamingMode;
1213
import com.alibaba.dashscope.utils.Constants;
@@ -44,6 +45,8 @@ public class MultiModalDialog {
4445

4546
private ApiServiceOption serviceOption; // Service option configuration
4647

48+
private ConnectionOptions connectionOptions;
49+
4750
private Emitter<Object> conversationEmitter; // Message emitter
4851

4952
private MultiModalRequestParam requestParam; // Request parameter
@@ -137,7 +140,36 @@ public MultiModalDialog(
137140

138141
this.requestParam = param;
139142
this.callback = callback;
140-
this.duplexApi = new SynchronizeFullDuplexApi<>(serviceOption);
143+
connectionOptions = ConnectionOptions.builder().build();
144+
this.connectionOptions.setUseDefaultClient(false);
145+
this.duplexApi = new SynchronizeFullDuplexApi<>(this.connectionOptions,serviceOption);
146+
}
147+
148+
149+
/**
150+
* Constructor initializes service options and creates a duplex communication API instance.
151+
*
152+
* param: param Request parameter
153+
* param: callback Callback interface
154+
* param: connectionOptions Connection options
155+
*/
156+
public MultiModalDialog(
157+
MultiModalRequestParam param, MultiModalDialogCallback callback, ConnectionOptions connectionOptions) {
158+
this.serviceOption =
159+
ApiServiceOption.builder()
160+
.protocol(Protocol.WEBSOCKET)
161+
.streamingMode(StreamingMode.DUPLEX)
162+
.outputMode(OutputMode.ACCUMULATE)
163+
.taskGroup(TaskGroup.AIGC.getValue())
164+
.task(Task.MULTIMODAL_GENERATION.getValue())
165+
.function(Function.GENERATION.getValue())
166+
.build();
167+
this.connectionOptions = connectionOptions;
168+
this.connectionOptions.setUseDefaultClient(false);
169+
170+
this.requestParam = param;
171+
this.callback = callback;
172+
this.duplexApi = new SynchronizeFullDuplexApi<>(this.connectionOptions,serviceOption);
141173
}
142174

143175
/**

src/main/java/com/alibaba/dashscope/protocol/ClientProviders.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import com.alibaba.dashscope.protocol.okhttp.OkHttpClientFactory;
66
import com.alibaba.dashscope.protocol.okhttp.OkHttpHttpClient;
77
import com.alibaba.dashscope.protocol.okhttp.OkHttpWebSocketClient;
8+
import com.alibaba.dashscope.protocol.okhttp.OkHttpWebSocketClientForAudio;
89

910
public class ClientProviders {
1011
public static HalfDuplexClient getHalfDuplexClient(String protocol) {
@@ -54,8 +55,14 @@ public static FullDuplexClient getFullDuplexClient(
5455
// create default config client, create default http client.
5556
return new OkHttpWebSocketClient(OkHttpClientFactory.getOkHttpClient(), passTaskStarted);
5657
} else {
57-
return new OkHttpWebSocketClient(
58-
OkHttpClientFactory.getNewOkHttpClient(connectionOptions), passTaskStarted);
58+
if (connectionOptions.isUseDefaultClient()) {
59+
return new OkHttpWebSocketClient(
60+
OkHttpClientFactory.getNewOkHttpClient(connectionOptions), passTaskStarted);
61+
}else {
62+
// create custom client for audio models
63+
return new OkHttpWebSocketClientForAudio(
64+
OkHttpClientFactory.getNewOkHttpClient(connectionOptions), passTaskStarted);
65+
}
5966
}
6067
}
6168
}

src/main/java/com/alibaba/dashscope/protocol/ConnectionOptions.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ public final class ConnectionOptions {
3131
private Duration connectTimeout;
3232
private Duration writeTimeout;
3333
private Duration readTimeout;
34+
private boolean useDefaultClient = true;
3435

3536
public Duration getConnectTimeout() {
3637
return getDuration(connectTimeout, DEFAULT_CONNECT_TIMEOUT, CONNECTION_TIMEOUT_ENV);
@@ -84,4 +85,13 @@ public Proxy getProxy() {
8485
}
8586
return null;
8687
}
88+
89+
public boolean isUseDefaultClient() {
90+
return useDefaultClient;
91+
}
92+
93+
public void setUseDefaultClient(boolean useDefaultClient) {
94+
this.useDefaultClient = useDefaultClient;
95+
}
96+
8797
}

src/main/java/com/alibaba/dashscope/protocol/okhttp/OkHttpWebSocketClient.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ public class OkHttpWebSocketClient extends WebSocketListener
4141
private AtomicBoolean isOpen = new AtomicBoolean(false);
4242
private AtomicBoolean isClosed = new AtomicBoolean(false);
4343
// indicate the first response is received.
44-
private AtomicBoolean isFirstMessage = new AtomicBoolean(false);
44+
protected AtomicBoolean isFirstMessage = new AtomicBoolean(false);
4545
// used for get request response
46-
private FlowableEmitter<DashScopeResult> responseEmitter;
46+
protected FlowableEmitter<DashScopeResult> responseEmitter;
4747
// is the result is flatten format.
4848
private boolean isFlattenResult;
4949
private FlowableEmitter<DashScopeResult> connectionEmitter;
@@ -363,7 +363,7 @@ public void onOpen(WebSocket webSocket, Response response) {
363363
}
364364
}
365365

366-
private void sendTextWithRetry(
366+
protected void sendTextWithRetry(
367367
String apiKey,
368368
boolean isSecurityCheck,
369369
String message,
@@ -402,7 +402,7 @@ private void sendTextWithRetry(
402402
}
403403
}
404404

405-
private void sendBinaryWithRetry(
405+
protected void sendBinaryWithRetry(
406406
String apiKey,
407407
boolean isSecurityCheck,
408408
ByteString message,
@@ -555,7 +555,7 @@ public void run() throws Exception {
555555
});
556556
}
557557

558-
private CompletableFuture<Void> sendStreamRequest(FullDuplexRequest req) {
558+
protected CompletableFuture<Void> sendStreamRequest(FullDuplexRequest req) {
559559
CompletableFuture<Void> future =
560560
CompletableFuture.runAsync(
561561
() -> {
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
package com.alibaba.dashscope.protocol.okhttp;
2+
3+
import com.alibaba.dashscope.protocol.FullDuplexRequest;
4+
import com.alibaba.dashscope.utils.JsonUtils;
5+
import com.google.gson.JsonObject;
6+
import io.reactivex.Flowable;
7+
import io.reactivex.functions.Action;
8+
import lombok.extern.slf4j.Slf4j;
9+
import okhttp3.OkHttpClient;
10+
import okio.ByteString;
11+
12+
import java.nio.ByteBuffer;
13+
import java.util.concurrent.*;
14+
import java.util.concurrent.atomic.AtomicBoolean;
15+
import java.util.concurrent.atomic.AtomicInteger;
16+
17+
/**
18+
* @author songsong.shao
19+
* @date 2025/11/5
20+
*/
21+
@Slf4j
22+
public class OkHttpWebSocketClientForAudio extends OkHttpWebSocketClient {
23+
24+
private static final AtomicInteger STREAMING_REQUEST_THREAD_NUM = new AtomicInteger(0);
25+
private static final AtomicBoolean SHUTDOWN_INITIATED = new AtomicBoolean(false);
26+
27+
private static final ExecutorService STREAMING_REQUEST_EXECUTOR =
28+
new ThreadPoolExecutor(1, 100, 60L, TimeUnit.SECONDS, new SynchronousQueue<>(), r -> {
29+
Thread t = new Thread(r, "WS-STREAMING-REQ-Worker-" + STREAMING_REQUEST_THREAD_NUM.updateAndGet(n -> n == Integer.MAX_VALUE ? 0 : n + 1));
30+
t.setDaemon(true);
31+
return t;
32+
});
33+
34+
public OkHttpWebSocketClientForAudio(OkHttpClient client, boolean passTaskStarted) {
35+
super(client, passTaskStarted);
36+
log.info("Use OkHttpWebSocketClientForAudio");
37+
}
38+
39+
@Override
40+
protected CompletableFuture<Void> sendStreamRequest(FullDuplexRequest req) {
41+
CompletableFuture<Void> future =
42+
CompletableFuture.runAsync(
43+
() -> {
44+
try {
45+
isFirstMessage.set(false);
46+
47+
JsonObject startMessage = req.getStartTaskMessage();
48+
log.info("send run-task request {}", JsonUtils.toJson(startMessage));
49+
String taskId =
50+
startMessage.get("header").getAsJsonObject().get("task_id").getAsString();
51+
// send start message out.
52+
sendTextWithRetry(
53+
req.getApiKey(),
54+
req.isSecurityCheck(),
55+
JsonUtils.toJson(startMessage),
56+
req.getWorkspace(),
57+
req.getHeaders(),
58+
req.getBaseWebSocketUrl());
59+
60+
Flowable<Object> streamingData = req.getStreamingData();
61+
streamingData.subscribe(
62+
data -> {
63+
try {
64+
if (data instanceof String) {
65+
JsonObject continueData = req.getContinueMessage((String) data, taskId);
66+
sendTextWithRetry(
67+
req.getApiKey(),
68+
req.isSecurityCheck(),
69+
JsonUtils.toJson(continueData),
70+
req.getWorkspace(),
71+
req.getHeaders(),
72+
req.getBaseWebSocketUrl());
73+
} else if (data instanceof byte[]) {
74+
sendBinaryWithRetry(
75+
req.getApiKey(),
76+
req.isSecurityCheck(),
77+
ByteString.of((byte[]) data),
78+
req.getWorkspace(),
79+
req.getHeaders(),
80+
req.getBaseWebSocketUrl());
81+
} else if (data instanceof ByteBuffer) {
82+
sendBinaryWithRetry(
83+
req.getApiKey(),
84+
req.isSecurityCheck(),
85+
ByteString.of((ByteBuffer) data),
86+
req.getWorkspace(),
87+
req.getHeaders(),
88+
req.getBaseWebSocketUrl());
89+
} else {
90+
JsonObject continueData = req.getContinueMessage(data, taskId);
91+
sendTextWithRetry(
92+
req.getApiKey(),
93+
req.isSecurityCheck(),
94+
JsonUtils.toJson(continueData),
95+
req.getWorkspace(),
96+
req.getHeaders(),
97+
req.getBaseWebSocketUrl());
98+
}
99+
} catch (Throwable ex) {
100+
log.error(String.format("sendStreamData exception: %s", ex.getMessage()));
101+
responseEmitter.onError(ex);
102+
}
103+
},
104+
err -> {
105+
log.error(String.format("Get stream data error!"));
106+
responseEmitter.onError(err);
107+
},
108+
new Action() {
109+
@Override
110+
public void run() throws Exception {
111+
log.debug(String.format("Stream data send completed!"));
112+
sendTextWithRetry(
113+
req.getApiKey(),
114+
req.isSecurityCheck(),
115+
JsonUtils.toJson(req.getFinishedTaskMessage(taskId)),
116+
req.getWorkspace(),
117+
req.getHeaders(),
118+
req.getBaseWebSocketUrl());
119+
}
120+
});
121+
} catch (Throwable ex) {
122+
log.error(String.format("sendStreamData exception: %s", ex.getMessage()));
123+
responseEmitter.onError(ex);
124+
}
125+
});
126+
return future;
127+
}
128+
129+
static {//auto close when jvm shutdown
130+
Runtime.getRuntime().addShutdownHook(new Thread(OkHttpWebSocketClientForAudio::shutdownStreamingExecutor));
131+
}
132+
/**
133+
* Shutdown the streaming request executor gracefully.
134+
* This method should be called when the application is shutting down
135+
* to ensure proper resource cleanup.
136+
*/
137+
private static void shutdownStreamingExecutor() {
138+
if (!SHUTDOWN_INITIATED.compareAndSet(false, true)) {
139+
log.debug("Shutdown already in progress");
140+
return;
141+
}
142+
143+
if (!STREAMING_REQUEST_EXECUTOR.isShutdown()) {
144+
log.debug("Shutting down streaming request executor...");
145+
STREAMING_REQUEST_EXECUTOR.shutdown();
146+
try {
147+
// Wait up to 60 seconds for existing tasks to terminate
148+
if (!STREAMING_REQUEST_EXECUTOR.awaitTermination(60, TimeUnit.SECONDS)) {
149+
log.warn("Streaming request executor did not terminate in 60 seconds, forcing shutdown...");
150+
STREAMING_REQUEST_EXECUTOR.shutdownNow();
151+
// Wait up to 60 seconds for tasks to respond to being cancelled
152+
if (!STREAMING_REQUEST_EXECUTOR.awaitTermination(60, TimeUnit.SECONDS)) {
153+
log.error("Streaming request executor did not terminate");
154+
}
155+
}
156+
} catch (InterruptedException ie) {
157+
// (Re-)Cancel if current thread also interrupted
158+
STREAMING_REQUEST_EXECUTOR.shutdownNow();
159+
// Preserve interrupt status
160+
Thread.currentThread().interrupt();
161+
}
162+
log.info("Streaming request executor shut down completed");
163+
}
164+
}
165+
}

0 commit comments

Comments
 (0)