diff --git a/agentscope-core/src/main/java/io/agentscope/core/model/DashScopeChatModel.java b/agentscope-core/src/main/java/io/agentscope/core/model/DashScopeChatModel.java index 4b22d45df..8863030fa 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/model/DashScopeChatModel.java +++ b/agentscope-core/src/main/java/io/agentscope/core/model/DashScopeChatModel.java @@ -230,14 +230,23 @@ private Flux streamWithHttpClient( if (stream) { // Streaming mode - return httpClient.stream(request) + return httpClient.stream( + request, + effectiveOptions.getAdditionalHeaders(), + effectiveOptions.getAdditionalBodyParams(), + effectiveOptions.getAdditionalQueryParams()) .map(response -> formatter.parseResponse(response, start)); } else { // Non-streaming mode return Flux.defer( () -> { try { - DashScopeResponse response = httpClient.call(request); + DashScopeResponse response = + httpClient.call( + request, + effectiveOptions.getAdditionalHeaders(), + effectiveOptions.getAdditionalBodyParams(), + effectiveOptions.getAdditionalQueryParams()); ChatResponse chatResponse = formatter.parseResponse(response, start); return Flux.just(chatResponse); } catch (Exception e) { diff --git a/agentscope-core/src/main/java/io/agentscope/core/model/DashScopeHttpClient.java b/agentscope-core/src/main/java/io/agentscope/core/model/DashScopeHttpClient.java index 8ede7f8b2..8549ea1ec 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/model/DashScopeHttpClient.java +++ b/agentscope-core/src/main/java/io/agentscope/core/model/DashScopeHttpClient.java @@ -109,17 +109,6 @@ public DashScopeHttpClient(String apiKey) { this(apiKey, null); } - /** - * Make a synchronous API call. - * - * @param request the DashScope request - * @return the DashScope response - * @throws DashScopeHttpException if the request fails - */ - public DashScopeResponse call(DashScopeRequest request) { - return call(request, null, null, null); - } - /** * Make a synchronous API call with additional HTTP parameters. * @@ -180,16 +169,6 @@ public DashScopeResponse call( } } - /** - * Make a streaming API call. - * - * @param request the DashScope request - * @return a Flux of DashScope responses (one per SSE event) - */ - public Flux stream(DashScopeRequest request) { - return stream(request, null, null, null); - } - /** * Make a streaming API call with additional HTTP parameters. * diff --git a/agentscope-core/src/test/java/io/agentscope/core/model/DashScopeChatModelTest.java b/agentscope-core/src/test/java/io/agentscope/core/model/DashScopeChatModelTest.java index 2a0c8d489..ecba79149 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/model/DashScopeChatModelTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/model/DashScopeChatModelTest.java @@ -28,11 +28,18 @@ import io.agentscope.core.formatter.dashscope.dto.DashScopeParameters; import io.agentscope.core.formatter.dashscope.dto.DashScopeRequest; import io.agentscope.core.message.Msg; +import io.agentscope.core.message.MsgRole; +import io.agentscope.core.message.TextBlock; import io.agentscope.core.model.test.ModelTestUtils; +import io.agentscope.core.model.transport.OkHttpTransport; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Tag; @@ -420,6 +427,95 @@ void testWithThinkingMode() { assertNotNull(thinkingModel); } + @Test + @DisplayName("DashScope chat model stream with additional headers and params") + void testDoStreamWithAdditionHeadersAndParams() throws Exception { + MockWebServer mockServer = new MockWebServer(); + mockServer.start(); + + mockServer.enqueue( + new MockResponse() + .setResponseCode(200) + .setBody("{\"request_id\":\"test\",\"output\":{\"choices\":[]}}") + .setHeader("Content-Type", "application/json")); + + DashScopeChatModel chatModel = + DashScopeChatModel.builder().apiKey(mockApiKey).modelName("qwen-plus").stream(true) + .enableThinking(true) + .enableSearch(true) + .baseUrl(mockServer.url("/").toString().replaceAll("/$", "")) + .httpTransport(OkHttpTransport.builder().build()) + .build(); + + chatModel + .doStream( + List.of( + Msg.builder() + .role(MsgRole.USER) + .content(TextBlock.builder().text("test").build()) + .build()), + List.of(), + GenerateOptions.builder() + .additionalHeaders(Map.of("custom", "custom-header")) + .additionalBodyParams(Map.of("custom", "custom-body")) + .additionalQueryParams(Map.of("custom", "custom-query")) + .build()) + .blockLast(); + + RecordedRequest recorded = mockServer.takeRequest(); + assertEquals("custom-header", recorded.getHeader("custom")); + assertEquals( + DashScopeHttpClient.TEXT_GENERATION_ENDPOINT + "?custom=custom-query", + recorded.getPath()); + assertTrue(recorded.getBody().readUtf8().contains("\"custom\":\"custom-body\"")); + + mockServer.close(); + } + + @Test + @DisplayName("DashScope chat model non-stream with additional headers and params") + void testDoNonStreamWithAdditionHeadersAndParams() throws Exception { + MockWebServer mockServer = new MockWebServer(); + mockServer.start(); + + mockServer.enqueue( + new MockResponse() + .setResponseCode(200) + .setBody("{\"request_id\":\"test\",\"output\":{\"choices\":[]}}") + .setHeader("Content-Type", "application/json")); + + DashScopeChatModel chatModel = + DashScopeChatModel.builder().apiKey(mockApiKey).modelName("qwen-plus").stream(true) + .stream(false) + .baseUrl(mockServer.url("/").toString().replaceAll("/$", "")) + .httpTransport(OkHttpTransport.builder().build()) + .build(); + + chatModel + .doStream( + List.of( + Msg.builder() + .role(MsgRole.USER) + .content(TextBlock.builder().text("test").build()) + .build()), + List.of(), + GenerateOptions.builder() + .additionalHeaders(Map.of("custom", "custom-header")) + .additionalBodyParams(Map.of("custom", "custom-body")) + .additionalQueryParams(Map.of("custom", "custom-query")) + .build()) + .blockLast(); + + RecordedRequest recorded = mockServer.takeRequest(); + assertEquals("custom-header", recorded.getHeader("custom")); + assertEquals( + DashScopeHttpClient.TEXT_GENERATION_ENDPOINT + "?custom=custom-query", + recorded.getPath()); + assertTrue(recorded.getBody().readUtf8().contains("\"custom\":\"custom-body\"")); + + mockServer.close(); + } + @Test @DisplayName("DashScope chat model apply thinking mode") void testApplyThinkingMode() { diff --git a/agentscope-core/src/test/java/io/agentscope/core/model/DashScopeHttpClientTest.java b/agentscope-core/src/test/java/io/agentscope/core/model/DashScopeHttpClientTest.java index c936c5d59..0d721c0c8 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/model/DashScopeHttpClientTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/model/DashScopeHttpClientTest.java @@ -28,7 +28,9 @@ import io.agentscope.core.formatter.dashscope.dto.DashScopeRequest; import io.agentscope.core.formatter.dashscope.dto.DashScopeResponse; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -137,7 +139,7 @@ void testCallTextGenerationApi() throws Exception { DashScopeRequest request = createTestRequest("qwen-plus", "Hello"); - DashScopeResponse response = client.call(request); + DashScopeResponse response = client.call(request, null, null, null); assertNotNull(response); assertEquals("test-request-id", response.getRequestId()); @@ -181,7 +183,7 @@ void testCallMultimodalApi() throws Exception { DashScopeRequest request = createTestRequest("qwen-vl-max", "What's in this image?"); - DashScopeResponse response = client.call(request); + DashScopeResponse response = client.call(request, null, null, null); assertNotNull(response); assertEquals("multimodal-request-id", response.getRequestId()); @@ -209,7 +211,7 @@ void testStreamTextGenerationApi() { DashScopeRequest request = createTestRequest("qwen-plus", "Hi"); List responses = new ArrayList<>(); - StepVerifier.create(client.stream(request)) + StepVerifier.create(client.stream(request, null, null, null)) .recordWith(() -> responses) .expectNextCount(2) .verifyComplete(); @@ -239,7 +241,7 @@ void testStreamMultimodalApi() { DashScopeRequest request = createTestRequest("qwen-vl-max", "Describe image"); - StepVerifier.create(client.stream(request)) + StepVerifier.create(client.stream(request, null, null, null)) .expectNextMatches( r -> "I see" @@ -273,7 +275,7 @@ void testApiErrorHandling() { DashScopeHttpClient.DashScopeHttpException exception = assertThrows( DashScopeHttpClient.DashScopeHttpException.class, - () -> client.call(request)); + () -> client.call(request, null, null, null)); assertTrue(exception.getMessage().contains("Invalid API key")); } @@ -291,7 +293,7 @@ void testHttpErrorHandling() { DashScopeHttpClient.DashScopeHttpException exception = assertThrows( DashScopeHttpClient.DashScopeHttpException.class, - () -> client.call(request)); + () -> client.call(request, null, null, null)); assertEquals(500, exception.getStatusCode()); } @@ -322,7 +324,7 @@ void testRequestHeaders() throws Exception { .setHeader("Content-Type", "application/json")); DashScopeRequest request = createTestRequest("qwen-plus", "test"); - client.call(request); + client.call(request, null, null, null); RecordedRequest recorded = mockServer.takeRequest(); assertEquals("Bearer test-api-key", recorded.getHeader("Authorization")); @@ -339,7 +341,7 @@ void testStreamingRequestHeaders() throws Exception { .setHeader("Content-Type", "text/event-stream")); DashScopeRequest request = createTestRequest("qwen-plus", "test"); - client.stream(request).blockLast(); + client.stream(request, null, null, null).blockLast(); RecordedRequest recorded = mockServer.takeRequest(); assertEquals("enable", recorded.getHeader("X-DashScope-SSE")); @@ -482,6 +484,61 @@ void testHeaderOverride() throws Exception { assertEquals("application/json; charset=utf-8", recorded.getHeader("Content-Type")); } + @Test + void testCallAdditionalHeadersAndParams() throws Exception { + mockServer.enqueue( + new MockResponse() + .setResponseCode(200) + .setBody("{\"request_id\":\"test\",\"output\":{\"choices\":[]}}") + .setHeader("Content-Type", "application/json")); + + DashScopeRequest request = createTestRequest("qwen-plus", "test"); + // Override the Content-Type header + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put("custom", "custom-header"); + Map additionalBodyParams = new HashMap<>(); + additionalBodyParams.put("custom", "custom-body"); + Map additionalQueryParams = new HashMap<>(); + additionalQueryParams.put("custom", "custom-query"); + + client.call(request, additionalHeaders, additionalBodyParams, additionalQueryParams); + + RecordedRequest recorded = mockServer.takeRequest(); + assertEquals("custom-header", recorded.getHeader("custom")); + assertEquals( + DashScopeHttpClient.TEXT_GENERATION_ENDPOINT + "?custom=custom-query", + recorded.getPath()); + assertTrue(recorded.getBody().readUtf8().contains("\"custom\":\"custom-body\"")); + } + + @Test + void testStreamAdditionalHeadersAndParams() throws Exception { + mockServer.enqueue( + new MockResponse() + .setResponseCode(200) + .setBody("{\"request_id\":\"test\",\"output\":{\"choices\":[]}}") + .setHeader("Content-Type", "application/json")); + + DashScopeRequest request = createTestRequest("qwen-plus", "test"); + // Override the Content-Type header + Map additionalHeaders = new HashMap<>(); + additionalHeaders.put("custom", "custom-header"); + Map additionalBodyParams = new HashMap<>(); + additionalBodyParams.put("custom", "custom-value"); + Map additionalQueryParams = new HashMap<>(); + additionalQueryParams.put("custom", "custom-value"); + + client.stream(request, additionalHeaders, additionalBodyParams, additionalQueryParams) + .blockLast(); + + RecordedRequest recorded = mockServer.takeRequest(); + assertEquals("custom-header", recorded.getHeader("custom")); + assertEquals( + DashScopeHttpClient.TEXT_GENERATION_ENDPOINT + "?custom=custom-value", + recorded.getPath()); + assertTrue(recorded.getBody().readUtf8().contains("\"custom\":\"custom-value\"")); + } + // ==================== DashScopeHttpException Tests ==================== @Test