Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ private void response() {
ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard);
tensors.setStatusCode(statusCode);
actionListener.onResponse(new Tuple<>(executionContext.getSequence(), tensors));
} catch (IllegalArgumentException e) {
actionListener.onFailure(e);
} catch (Exception e) {
log.error("Failed to process response body: {}", body, e);
actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -31,7 +32,6 @@
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.script.ScriptService;
import org.reactivestreams.Publisher;
Expand Down Expand Up @@ -191,7 +191,7 @@ public void test_onError() {
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(captor.capture());
assert captor.getValue() instanceof OpenSearchStatusException;
assert captor.getValue().getMessage().equals("Error communicating with remote model: runtime exception");
assertEquals("Error communicating with remote model: runtime exception", captor.getValue().getMessage());
}

@Test
Expand All @@ -209,7 +209,7 @@ public void test_onSubscribe() {
public void test_onNext() {
test_onSubscribe();// set the subscription to non-null.
responseSubscriber.onNext(ByteBuffer.wrap("hello world".getBytes()));
assert mlSdkAsyncHttpResponseHandler.getResponseBody().toString().equals("hello world");
assertEquals("hello world", mlSdkAsyncHttpResponseHandler.getResponseBody().toString());
}

@Test
Expand All @@ -221,7 +221,7 @@ public void test_MLResponseSubscriber_onError() {
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue() instanceof OpenSearchStatusException;
assert captor.getValue().getMessage().equals("Remote service returned error status 500 with empty body");
assertEquals("Remote service returned error status 500 with empty body", captor.getValue().getMessage());
}

@Test
Expand Down Expand Up @@ -283,7 +283,7 @@ public void test_onComplete_failed() {
mlSdkAsyncHttpResponseHandler.onStream(stream);
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED");
assertEquals("Error from remote service: Model current status is: FAILED", captor.getValue().getMessage());
assert captor.getValue().status().getStatus() == 500;
}

Expand All @@ -302,7 +302,7 @@ public void test_onComplete_empty_response_body() {
mlSdkAsyncHttpResponseHandler.onStream(stream);
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("Remote service returned empty response body");
assertEquals("Remote service returned empty response body", captor.getValue().getMessage());
}

@Test
Expand Down Expand Up @@ -380,14 +380,12 @@ public void test_onComplete_throttle_exception_onFailure() {

ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(RemoteConnectorThrottlingException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor
.getValue()
.getMessage()
.equals(
"Error from remote service: The request was denied due to remote server throttling. "
+ "To change the retry policy and behavior, please update the connector client_config."
);
assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST;
assertEquals(
"Error from remote service: The request was denied due to remote server throttling. "
+ "To change the retry policy and behavior, please update the connector client_config.",
captor.getValue().getMessage()
);
}

@Test
Expand Down Expand Up @@ -416,8 +414,39 @@ public void test_onComplete_processOutputFail_onFailure() {
};
mlSdkAsyncHttpResponseHandler.onStream(stream);

ArgumentCaptor<MLException> captor = ArgumentCaptor.forClass(MLException.class);
ArgumentCaptor<IllegalArgumentException> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("Fail to execute PREDICT in aws connector");
assertEquals("no PREDICT action found", captor.getValue().getMessage());
}

/**
* Asserts that IllegalArgumentException is propagated where post-processing function fails
* on response
*/
@Test
public void onComplete_InvalidEmbeddingBedRockPostProcessingOccurs_IllegalArgumentExceptionThrown() {
String invalidEmbeddingResponse = "{ \"embedding\": [[1]] }";

mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
Publisher<ByteBuffer> stream = s -> {
try {
s.onSubscribe(mock(Subscription.class));
s.onNext(ByteBuffer.wrap(invalidEmbeddingResponse.getBytes()));
s.onComplete();
} catch (Throwable e) {
s.onError(e);
}
};
mlSdkAsyncHttpResponseHandler.onStream(stream);

ArgumentCaptor<IllegalArgumentException> exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class);
verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());

// Error message
assertEquals(
"BedrockEmbeddingPostProcessFunction exception message should match",
"The embedding should be a non-empty List containing Float values.",
exceptionCaptor.getValue().getMessage()
);
}
}
Loading