Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,23 @@ private Flux<ChatResponse> streamWithHttpClient(

if (stream) {
// Streaming mode
return httpClient.stream(request)
return httpClient.stream(
request,
effectiveOptions.getAdditionalHeaders(),
effectiveOptions.getAdditionalBodyParams(),
effectiveOptions.getAdditionalQueryParams())
.map(response -> formatter.parseResponse(response, start));
} else {
// Non-streaming mode
return Flux.defer(
() -> {
try {
DashScopeResponse response = httpClient.call(request);
DashScopeResponse response =
httpClient.call(
request,
effectiveOptions.getAdditionalHeaders(),
effectiveOptions.getAdditionalBodyParams(),
effectiveOptions.getAdditionalQueryParams());
ChatResponse chatResponse = formatter.parseResponse(response, start);
return Flux.just(chatResponse);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,17 +109,6 @@ public DashScopeHttpClient(String apiKey) {
this(apiKey, null);
}

/**
* Make a synchronous API call.
*
* @param request the DashScope request
* @return the DashScope response
* @throws DashScopeHttpException if the request fails
*/
public DashScopeResponse call(DashScopeRequest request) {
return call(request, null, null, null);
}

/**
* Make a synchronous API call with additional HTTP parameters.
*
Expand Down Expand Up @@ -180,16 +169,6 @@ public DashScopeResponse call(
}
}

/**
* Make a streaming API call.
*
* @param request the DashScope request
* @return a Flux of DashScope responses (one per SSE event)
*/
public Flux<DashScopeResponse> stream(DashScopeRequest request) {
return stream(request, null, null, null);
}

/**
* Make a streaming API call with additional HTTP parameters.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,18 @@
import io.agentscope.core.formatter.dashscope.dto.DashScopeParameters;
import io.agentscope.core.formatter.dashscope.dto.DashScopeRequest;
import io.agentscope.core.message.Msg;
import io.agentscope.core.message.MsgRole;
import io.agentscope.core.message.TextBlock;
import io.agentscope.core.model.test.ModelTestUtils;
import io.agentscope.core.model.transport.OkHttpTransport;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Tag;
Expand Down Expand Up @@ -420,6 +427,95 @@ void testWithThinkingMode() {
assertNotNull(thinkingModel);
}

@Test
@DisplayName("DashScope chat model stream with additional headers and params")
void testDoStreamWithAdditionHeadersAndParams() throws Exception {
MockWebServer mockServer = new MockWebServer();
mockServer.start();

mockServer.enqueue(
new MockResponse()
.setResponseCode(200)
.setBody("{\"request_id\":\"test\",\"output\":{\"choices\":[]}}")
.setHeader("Content-Type", "application/json"));

DashScopeChatModel chatModel =
DashScopeChatModel.builder().apiKey(mockApiKey).modelName("qwen-plus").stream(true)
.enableThinking(true)
.enableSearch(true)
.baseUrl(mockServer.url("/").toString().replaceAll("/$", ""))
.httpTransport(OkHttpTransport.builder().build())
.build();

chatModel
.doStream(
List.of(
Msg.builder()
.role(MsgRole.USER)
.content(TextBlock.builder().text("test").build())
.build()),
List.of(),
GenerateOptions.builder()
.additionalHeaders(Map.of("custom", "custom-header"))
.additionalBodyParams(Map.of("custom", "custom-body"))
.additionalQueryParams(Map.of("custom", "custom-query"))
.build())
.blockLast();

RecordedRequest recorded = mockServer.takeRequest();
assertEquals("custom-header", recorded.getHeader("custom"));
assertEquals(
DashScopeHttpClient.TEXT_GENERATION_ENDPOINT + "?custom=custom-query",
recorded.getPath());
assertTrue(recorded.getBody().readUtf8().contains("\"custom\":\"custom-body\""));

mockServer.close();
}

@Test
@DisplayName("DashScope chat model non-stream with additional headers and params")
void testDoNonStreamWithAdditionHeadersAndParams() throws Exception {
MockWebServer mockServer = new MockWebServer();
mockServer.start();

mockServer.enqueue(
new MockResponse()
.setResponseCode(200)
.setBody("{\"request_id\":\"test\",\"output\":{\"choices\":[]}}")
.setHeader("Content-Type", "application/json"));

DashScopeChatModel chatModel =
DashScopeChatModel.builder().apiKey(mockApiKey).modelName("qwen-plus").stream(true)
.stream(false)
.baseUrl(mockServer.url("/").toString().replaceAll("/$", ""))
.httpTransport(OkHttpTransport.builder().build())
.build();

chatModel
.doStream(
List.of(
Msg.builder()
.role(MsgRole.USER)
.content(TextBlock.builder().text("test").build())
.build()),
List.of(),
GenerateOptions.builder()
.additionalHeaders(Map.of("custom", "custom-header"))
.additionalBodyParams(Map.of("custom", "custom-body"))
.additionalQueryParams(Map.of("custom", "custom-query"))
.build())
.blockLast();

RecordedRequest recorded = mockServer.takeRequest();
assertEquals("custom-header", recorded.getHeader("custom"));
assertEquals(
DashScopeHttpClient.TEXT_GENERATION_ENDPOINT + "?custom=custom-query",
recorded.getPath());
assertTrue(recorded.getBody().readUtf8().contains("\"custom\":\"custom-body\""));

mockServer.close();
}

@Test
@DisplayName("DashScope chat model apply thinking mode")
void testApplyThinkingMode() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import io.agentscope.core.formatter.dashscope.dto.DashScopeRequest;
import io.agentscope.core.formatter.dashscope.dto.DashScopeResponse;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
Expand Down Expand Up @@ -137,7 +139,7 @@ void testCallTextGenerationApi() throws Exception {

DashScopeRequest request = createTestRequest("qwen-plus", "Hello");

DashScopeResponse response = client.call(request);
DashScopeResponse response = client.call(request, null, null, null);

assertNotNull(response);
assertEquals("test-request-id", response.getRequestId());
Expand Down Expand Up @@ -181,7 +183,7 @@ void testCallMultimodalApi() throws Exception {

DashScopeRequest request = createTestRequest("qwen-vl-max", "What's in this image?");

DashScopeResponse response = client.call(request);
DashScopeResponse response = client.call(request, null, null, null);

assertNotNull(response);
assertEquals("multimodal-request-id", response.getRequestId());
Expand Down Expand Up @@ -209,7 +211,7 @@ void testStreamTextGenerationApi() {
DashScopeRequest request = createTestRequest("qwen-plus", "Hi");

List<DashScopeResponse> responses = new ArrayList<>();
StepVerifier.create(client.stream(request))
StepVerifier.create(client.stream(request, null, null, null))
.recordWith(() -> responses)
.expectNextCount(2)
.verifyComplete();
Expand Down Expand Up @@ -239,7 +241,7 @@ void testStreamMultimodalApi() {

DashScopeRequest request = createTestRequest("qwen-vl-max", "Describe image");

StepVerifier.create(client.stream(request))
StepVerifier.create(client.stream(request, null, null, null))
.expectNextMatches(
r ->
"I see"
Expand Down Expand Up @@ -273,7 +275,7 @@ void testApiErrorHandling() {
DashScopeHttpClient.DashScopeHttpException exception =
assertThrows(
DashScopeHttpClient.DashScopeHttpException.class,
() -> client.call(request));
() -> client.call(request, null, null, null));

assertTrue(exception.getMessage().contains("Invalid API key"));
}
Expand All @@ -291,7 +293,7 @@ void testHttpErrorHandling() {
DashScopeHttpClient.DashScopeHttpException exception =
assertThrows(
DashScopeHttpClient.DashScopeHttpException.class,
() -> client.call(request));
() -> client.call(request, null, null, null));

assertEquals(500, exception.getStatusCode());
}
Expand Down Expand Up @@ -322,7 +324,7 @@ void testRequestHeaders() throws Exception {
.setHeader("Content-Type", "application/json"));

DashScopeRequest request = createTestRequest("qwen-plus", "test");
client.call(request);
client.call(request, null, null, null);

RecordedRequest recorded = mockServer.takeRequest();
assertEquals("Bearer test-api-key", recorded.getHeader("Authorization"));
Expand All @@ -339,7 +341,7 @@ void testStreamingRequestHeaders() throws Exception {
.setHeader("Content-Type", "text/event-stream"));

DashScopeRequest request = createTestRequest("qwen-plus", "test");
client.stream(request).blockLast();
client.stream(request, null, null, null).blockLast();

RecordedRequest recorded = mockServer.takeRequest();
assertEquals("enable", recorded.getHeader("X-DashScope-SSE"));
Expand Down Expand Up @@ -482,6 +484,61 @@ void testHeaderOverride() throws Exception {
assertEquals("application/json; charset=utf-8", recorded.getHeader("Content-Type"));
}

@Test
void testCallAdditionalHeadersAndParams() throws Exception {
mockServer.enqueue(
new MockResponse()
.setResponseCode(200)
.setBody("{\"request_id\":\"test\",\"output\":{\"choices\":[]}}")
.setHeader("Content-Type", "application/json"));

DashScopeRequest request = createTestRequest("qwen-plus", "test");
// Override the Content-Type header
Map<String, String> additionalHeaders = new HashMap<>();
additionalHeaders.put("custom", "custom-header");
Map<String, Object> additionalBodyParams = new HashMap<>();
additionalBodyParams.put("custom", "custom-body");
Map<String, String> additionalQueryParams = new HashMap<>();
additionalQueryParams.put("custom", "custom-query");

client.call(request, additionalHeaders, additionalBodyParams, additionalQueryParams);

RecordedRequest recorded = mockServer.takeRequest();
assertEquals("custom-header", recorded.getHeader("custom"));
assertEquals(
DashScopeHttpClient.TEXT_GENERATION_ENDPOINT + "?custom=custom-query",
recorded.getPath());
assertTrue(recorded.getBody().readUtf8().contains("\"custom\":\"custom-body\""));
}

@Test
void testStreamAdditionalHeadersAndParams() throws Exception {
mockServer.enqueue(
new MockResponse()
.setResponseCode(200)
.setBody("{\"request_id\":\"test\",\"output\":{\"choices\":[]}}")
.setHeader("Content-Type", "application/json"));

DashScopeRequest request = createTestRequest("qwen-plus", "test");
// Override the Content-Type header
Map<String, String> additionalHeaders = new HashMap<>();
additionalHeaders.put("custom", "custom-header");
Map<String, Object> additionalBodyParams = new HashMap<>();
additionalBodyParams.put("custom", "custom-value");
Map<String, String> additionalQueryParams = new HashMap<>();
additionalQueryParams.put("custom", "custom-value");

client.stream(request, additionalHeaders, additionalBodyParams, additionalQueryParams)
.blockLast();

RecordedRequest recorded = mockServer.takeRequest();
assertEquals("custom-header", recorded.getHeader("custom"));
assertEquals(
DashScopeHttpClient.TEXT_GENERATION_ENDPOINT + "?custom=custom-value",
recorded.getPath());
assertTrue(recorded.getBody().readUtf8().contains("\"custom\":\"custom-value\""));
}

// ==================== DashScopeHttpException Tests ====================

@Test
Expand Down
Loading