From 76a11d94ad40a75d17ddcb54b2bde5253a13774b Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Wed, 24 Sep 2025 13:35:37 -0700 Subject: [PATCH] make MLSdkAsyncHttpResponseHandler return IllegalArgumentException (#4182) * make MLSdkAsyncHttpResponseHandler return IllegalArgumentException Signed-off-by: Brian Flores * empty commit to trigger CI Signed-off-by: Brian Flores --------- Signed-off-by: Brian Flores Co-authored-by: Dhrubo Saha (cherry picked from commit f83026c006be15217414e005f50c9cbecc54d10a) --- .../remote/MLSdkAsyncHttpResponseHandler.java | 2 + .../MLSdkAsyncHttpResponseHandlerTest.java | 59 ++++++++++++++----- 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java index 7da5830840..5383155662 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandler.java @@ -206,6 +206,8 @@ private void response() { ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard); tensors.setStatusCode(statusCode); actionListener.onResponse(new Tuple<>(executionContext.getSequence(), tensors)); + } catch (IllegalArgumentException e) { + actionListener.onFailure(e); } catch (Exception e) { log.error("Failed to process response body: {}", body, e); actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e)); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java index a16e9f7851..4875731aaa 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/MLSdkAsyncHttpResponseHandlerTest.java @@ -5,6 +5,7 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -31,7 +32,6 @@ import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.connector.MLPostProcessFunction; -import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.script.ScriptService; import org.reactivestreams.Publisher; @@ -191,7 +191,7 @@ public void test_onError() { ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(captor.capture()); assert captor.getValue() instanceof OpenSearchStatusException; - assert captor.getValue().getMessage().equals("Error communicating with remote model: runtime exception"); + assertEquals("Error communicating with remote model: runtime exception", captor.getValue().getMessage()); } @Test @@ -209,7 +209,7 @@ public void test_onSubscribe() { public void test_onNext() { test_onSubscribe();// set the subscription to non-null. responseSubscriber.onNext(ByteBuffer.wrap("hello world".getBytes())); - assert mlSdkAsyncHttpResponseHandler.getResponseBody().toString().equals("hello world"); + assertEquals("hello world", mlSdkAsyncHttpResponseHandler.getResponseBody().toString()); } @Test @@ -221,7 +221,7 @@ public void test_MLResponseSubscriber_onError() { ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); verify(actionListener, times(1)).onFailure(captor.capture()); assert captor.getValue() instanceof OpenSearchStatusException; - assert captor.getValue().getMessage().equals("Remote service returned error status 500 with empty body"); + assertEquals("Remote service returned error status 500 with empty body", captor.getValue().getMessage()); } @Test @@ -283,7 +283,7 @@ public void test_onComplete_failed() { mlSdkAsyncHttpResponseHandler.onStream(stream); ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener, times(1)).onFailure(captor.capture()); - assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED"); + assertEquals("Error from remote service: Model current status is: FAILED", captor.getValue().getMessage()); assert captor.getValue().status().getStatus() == 500; } @@ -302,7 +302,7 @@ public void test_onComplete_empty_response_body() { mlSdkAsyncHttpResponseHandler.onStream(stream); ArgumentCaptor captor = ArgumentCaptor.forClass(OpenSearchStatusException.class); verify(actionListener, times(1)).onFailure(captor.capture()); - assert captor.getValue().getMessage().equals("Remote service returned empty response body"); + assertEquals("Remote service returned empty response body", captor.getValue().getMessage()); } @Test @@ -380,14 +380,12 @@ public void test_onComplete_throttle_exception_onFailure() { ArgumentCaptor captor = ArgumentCaptor.forClass(RemoteConnectorThrottlingException.class); verify(actionListener, times(1)).onFailure(captor.capture()); - assert captor - .getValue() - .getMessage() - .equals( - "Error from remote service: The request was denied due to remote server throttling. " - + "To change the retry policy and behavior, please update the connector client_config." - ); assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST; + assertEquals( + "Error from remote service: The request was denied due to remote server throttling. " + + "To change the retry policy and behavior, please update the connector client_config.", + captor.getValue().getMessage() + ); } @Test @@ -416,8 +414,39 @@ public void test_onComplete_processOutputFail_onFailure() { }; mlSdkAsyncHttpResponseHandler.onStream(stream); - ArgumentCaptor captor = ArgumentCaptor.forClass(MLException.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(IllegalArgumentException.class); verify(actionListener, times(1)).onFailure(captor.capture()); - assert captor.getValue().getMessage().equals("Fail to execute PREDICT in aws connector"); + assertEquals("no PREDICT action found", captor.getValue().getMessage()); + } + + /** + * Asserts that IllegalArgumentException is propagated where post-processing function fails + * on response + */ + @Test + public void onComplete_InvalidEmbeddingBedRockPostProcessingOccurs_IllegalArgumentExceptionThrown() { + String invalidEmbeddingResponse = "{ \"embedding\": [[1]] }"; + + mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse); + Publisher stream = s -> { + try { + s.onSubscribe(mock(Subscription.class)); + s.onNext(ByteBuffer.wrap(invalidEmbeddingResponse.getBytes())); + s.onComplete(); + } catch (Throwable e) { + s.onError(e); + } + }; + mlSdkAsyncHttpResponseHandler.onStream(stream); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class); + verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); + + // Error message + assertEquals( + "BedrockEmbeddingPostProcessFunction exception message should match", + "The embedding should be a non-empty List containing Float values.", + exceptionCaptor.getValue().getMessage() + ); } }