diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java index bf09ed9d24..4a07b79cab 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java @@ -17,8 +17,8 @@ import static org.opensearch.ml.utils.RestActionUtils.isAsync; import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; +import java.io.ByteArrayOutputStream; import java.io.IOException; -import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.util.LinkedHashMap; import java.util.List; @@ -48,7 +48,6 @@ import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; @@ -158,10 +157,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client ); channel.prepareResponse(RestStatus.OK, headers); - Flux.from(channel).ofType(HttpChunk.class).concatMap(chunk -> { - final CompletableFuture future = new CompletableFuture<>(); + Flux.from(channel).ofType(HttpChunk.class).collectList().flatMap(chunks -> { try { - MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, chunk.content()); + BytesReference completeContent = combineChunks(chunks); + MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent); + + final CompletableFuture future = new CompletableFuture<>(); StreamTransportResponseHandler handler = new StreamTransportResponseHandler() { @Override public void handleStreamResponse(StreamTransportResponse streamResponse) { @@ -214,19 +215,23 @@ public MLTaskResponse read(StreamInput in) throws IOException { handler ); - } catch (IOException e) { - throw new MLException("Got an exception in flux.", e); + return Mono.fromCompletionStage(future); + } catch (Exception e) { + log.error("Failed to parse or process request", e); + return Mono.error(e); } - - return Mono.fromCompletionStage(future); - }).doOnNext(channel::sendChunk).onErrorComplete(ex -> { - // Error handling + }).doOnNext(channel::sendChunk).onErrorResume(ex -> { + log.error("Error occurred", ex); try { - channel.sendResponse(new BytesRestResponse(channel, (Exception) ex)); - return true; - } catch (final IOException e) { - throw new UncheckedIOException(e); + String errorMessage = ex instanceof IOException + ? "Failed to parse request: " + ex.getMessage() + : "Error processing request: " + ex.getMessage(); + HttpChunk errorChunk = createHttpChunk("data: {\"error\": \"" + errorMessage.replace("\"", "\\\"") + "\"}\n\n", true); + channel.sendChunk(errorChunk); + } catch (Exception e) { + log.error("Failed to send error chunk", e); } + return Mono.empty(); }).subscribe(); }; @@ -402,6 +407,20 @@ private String extractTensorResult(MLTaskResponse response, String tensorName) { return Map.of(); } + @VisibleForTesting + BytesReference combineChunks(List chunks) { + try { + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); + for (HttpChunk chunk : chunks) { + chunk.content().writeTo(buffer); + } + return BytesReference.fromByteBuffer(ByteBuffer.wrap(buffer.toByteArray())); + } catch (IOException e) { + log.error("Failed to combine chunks", e); + throw new OpenSearchStatusException("Failed to combine request chunks", RestStatus.INTERNAL_SERVER_ERROR, e); + } + } + private HttpChunk createHttpChunk(String sseData, boolean isLast) { BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes())); return new HttpChunk() { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteStreamActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteStreamActionTests.java index a6c6fc9166..6263a1bf28 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteStreamActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteStreamActionTests.java @@ -33,7 +33,9 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.rest.RestStatus; +import org.opensearch.http.HttpChunk; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.agent.LLMSpec; @@ -302,4 +304,71 @@ public void testGetRequestAgentFrameworkDisabled() { when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); } + + @Test + public void testCombineChunksWithSingleChunk() { + String testContent = "{\"parameters\":{\"question\":\"test\"}}"; + BytesArray bytesArray = new BytesArray(testContent); + + HttpChunk mockChunk = mock(HttpChunk.class); + when(mockChunk.content()).thenReturn(bytesArray); + + BytesReference result = restAction.combineChunks(List.of(mockChunk)); + + assertNotNull(result); + assertEquals(testContent, result.utf8ToString()); + } + + @Test + public void testCombineChunksWithMultipleChunks() { + String chunk1Content = "{\"parameters\":"; + String chunk2Content = "{\"question\":"; + String chunk3Content = "\"test\"}}"; + + BytesArray bytes1 = new BytesArray(chunk1Content); + BytesArray bytes2 = new BytesArray(chunk2Content); + BytesArray bytes3 = new BytesArray(chunk3Content); + + HttpChunk mockChunk1 = mock(HttpChunk.class); + HttpChunk mockChunk2 = mock(HttpChunk.class); + HttpChunk mockChunk3 = mock(HttpChunk.class); + + when(mockChunk1.content()).thenReturn(bytes1); + when(mockChunk2.content()).thenReturn(bytes2); + when(mockChunk3.content()).thenReturn(bytes3); + + BytesReference result = restAction.combineChunks(List.of(mockChunk1, mockChunk2, mockChunk3)); + + assertNotNull(result); + String expectedContent = chunk1Content + chunk2Content + chunk3Content; + assertEquals(expectedContent, result.utf8ToString()); + } + + @Test + public void testCombineChunksWithEmptyList() { + BytesReference result = restAction.combineChunks(List.of()); + + assertNotNull(result); + assertEquals(0, result.length()); + } + + @Test + public void testCombineChunksWithLargeContent() { + StringBuilder largeContent = new StringBuilder(); + for (int i = 0; i < 1000; i++) { + largeContent.append("chunk").append(i).append(","); + } + String content = largeContent.toString(); + + BytesArray bytesArray = new BytesArray(content); + + HttpChunk mockChunk = mock(HttpChunk.class); + when(mockChunk.content()).thenReturn(bytesArray); + + BytesReference result = restAction.combineChunks(List.of(mockChunk)); + + assertNotNull(result); + assertEquals(content.length(), result.length()); + assertEquals(content, result.utf8ToString()); + } }