@@ -5,19 +5,40 @@

package org.opensearch.ml.rest;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.Before;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.MLResultDataType;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.DefaultLlmImpl;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;

import lombok.SneakyThrows;
import lombok.extern.log4j.Log4j2;
@@ -82,6 +103,110 @@ public void test_bedrock_embedding_model() throws Exception {
}
}

public void testChatCompletionBedrockContentFormat() throws Exception {
Map<String, Object> response = Map.of("content", List.of(Map.of("text", "Claude V3 response text")));

Map<String, Object> result = invokeBedrockInference(response);

assertTrue(result.containsKey("answers"));
assertEquals("Claude V3 response text", ((List<?>) result.get("answers")).get(0));
}
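
// Companion tests (sketches, not part of the original change): the method names and
// expected strings below are assumptions. They exercise the legacy "completion" format
// and the unsupported-format error path of the new Bedrock handling.
public void testChatCompletionBedrockLegacyCompletionFormat() throws Exception {
Map<String, Object> response = Map.of("completion", "Claude V2 completion text");

Map<String, Object> result = invokeBedrockInference(response);

assertTrue(result.containsKey("answers"));
assertEquals("Claude V2 completion text", ((List<?>) result.get("answers")).get(0));
}

public void testChatCompletionBedrockUnsupportedFormat() throws Exception {
Map<String, Object> response = Map.of("unexpected_field", "value");

Map<String, Object> result = invokeBedrockInference(response);

// Assumes errors are surfaced through onResponse, as collected by invokeBedrockInference
List<?> errors = (List<?>) result.get("errors");
assertFalse(errors.isEmpty());
assertTrue(((String) errors.get(0)).startsWith("Unsupported Bedrock response format"));
}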

private static void injectMlClient(DefaultLlmImpl connector, Object mlClient) {
try {
Field field = null;
// Try common field names. Adjust if the actual field is named differently.
try {
field = DefaultLlmImpl.class.getDeclaredField("mlClient");
} catch (NoSuchFieldException e) {
// fallback if different field name
field = DefaultLlmImpl.class.getDeclaredField("client");
}
field.setAccessible(true);
field.set(connector, mlClient);
} catch (ReflectiveOperationException e) {
throw new RuntimeException("Failed to inject mlClient into DefaultLlmImpl", e);
}
}
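
// Alternative injection strategy (a sketch; an assumed helper, not in the original change):
// rather than guessing field names, scan DefaultLlmImpl's declared fields for the first
// one assignable from MachineLearningInternalClient.
private static Field findMlClientField() {
for (Field f : DefaultLlmImpl.class.getDeclaredFields()) {
if (MachineLearningInternalClient.class.isAssignableFrom(f.getType())) {
return f;
}
}
throw new IllegalStateException("No MachineLearningInternalClient field found in DefaultLlmImpl");
}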

private Map<String, Object> invokeBedrockInference(Map<String, Object> mockResponse) throws Exception {
// Create DefaultLlmImpl and mock ML client
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", null); // client can be null here; the ML client is injected via reflection below
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
injectMlClient(connector, mlClient);

// Wrap mockResponse inside a ModelTensor -> ModelTensors -> ModelTensorOutput -> MLOutput
ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, mockResponse);
ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
// Do NOT depend on the ActionFuture return path; instead stub the asynchronous
// predict(...) call so it drives the ActionListener directly with our mlOutput
doAnswer(invocation -> {
@SuppressWarnings("unchecked")
ActionListener<MLOutput> listener = (ActionListener<MLOutput>) invocation.getArguments()[2];
// Simulate successful ML response
listener.onResponse(mlOutput);
return null;
}).when(mlClient).predict(any(), any(), any());

// Prepare input (use BEDROCK provider so bedrock branch is taken)
ChatCompletionInput input = new ChatCompletionInput(
"bedrock/model",
"question",
Collections.emptyList(),
Collections.emptyList(),
0,
"prompt",
"instructions",
Llm.ModelProvider.BEDROCK,
null,
null
);

// Synchronously wait for callback result
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<Map<String, Object>> resultRef = new AtomicReference<>();

connector.doChatCompletion(input, new ActionListener<>() {
@Override
public void onResponse(ChatCompletionOutput output) {
Map<String, Object> map = new HashMap<>();
map.put("answers", output.getAnswers());
map.put("errors", output.getErrors());
resultRef.set(map);
latch.countDown();
}

@Override
public void onFailure(Exception e) {
Map<String, Object> map = new HashMap<>();
map.put("answers", Collections.emptyList());
map.put("errors", List.of(e.getMessage()));
resultRef.set(map);
latch.countDown();
}
});

boolean completed = latch.await(5, TimeUnit.SECONDS);
if (!completed) {
throw new RuntimeException("Timed out waiting for doChatCompletion callback");
}
return resultRef.get();
}

private void validateErrorOutput(String msg, Map<String, Object> output, String expectedError) {
assertTrue(msg, output.containsKey("error"));
Object error = output.get("error");

if (error instanceof Map) {
assertEquals(msg, expectedError, ((Map<?, ?>) error).get("message"));
} else if (error instanceof String) {
assertEquals(msg, expectedError, error);
} else {
fail("Unexpected error format: " + error.getClass());
}
}

private void validateOutput(String errorMsg, Map<String, Object> output) {
assertTrue(errorMsg, output.containsKey("output"));
assertTrue(errorMsg, output.get("output") instanceof List);
@@ -470,26 +470,6 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
+ " }\n"
+ "}";

private static final String BM25_SEARCH_REQUEST_WITH_CONVO_WITH_LLM_RESPONSE_TEMPLATE = "{\n"
+ " \"_source\": [\"%s\"],\n"
+ " \"query\" : {\n"
+ " \"match\": {\"%s\": \"%s\"}\n"
+ " },\n"
+ " \"ext\": {\n"
+ " \"generative_qa_parameters\": {\n"
+ " \"llm_model\": \"%s\",\n"
+ " \"llm_question\": \"%s\",\n"
+ " \"memory_id\": \"%s\",\n"
+ " \"system_prompt\": \"%s\",\n"
+ " \"user_instructions\": \"%s\",\n"
+ " \"context_size\": %d,\n"
+ " \"message_size\": %d,\n"
+ " \"timeout\": %d,\n"
+ " \"llm_response_field\": \"%s\"\n"
+ " }\n"
+ " }\n"
+ "}";

private static final String BM25_SEARCH_REQUEST_WITH_CONVO_AND_IMAGE_TEMPLATE = "{\n"
+ " \"_source\": [\"%s\"],\n"
+ " \"query\" : {\n"
@@ -52,6 +52,9 @@ public class DefaultLlmImpl implements Llm {
private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role";
private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content";
private static final String CONNECTOR_OUTPUT_ERROR = "error";
private static final String BEDROCK_COMPLETION_FIELD = "completion";
private static final String BEDROCK_CONTENT_FIELD = "content";
private static final String BEDROCK_TEXT_FIELD = "text";

private final String openSearchModelId;

@@ -191,8 +194,38 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
}
} else if (provider == ModelProvider.BEDROCK) {
answerField = "completion";
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
// Handle Bedrock model responses (supports both legacy completion and newer content/text formats)
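// Illustrative payload shapes (assumed, abridged):
//   legacy text-completion models: {"completion": " ...generated text..."}
//   Claude v3 messages API:        {"content": [{"type": "text", "text": "..."}], ...}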

Object contentObj = dataAsMap.get(BEDROCK_CONTENT_FIELD);
if (contentObj == null) {
// Legacy completion-style format
Object completion = dataAsMap.get(BEDROCK_COMPLETION_FIELD);
if (completion != null) {
answers.add(completion.toString());
} else {
errors.add("Unsupported Bedrock response format: " + dataAsMap.keySet());
log.error("Unknown Bedrock response format: {}", dataAsMap);
}
} else {
// Fail-fast checks for new content/text format
if (!(contentObj instanceof List<?> contentList)) {
errors.add("Unexpected type for '" + BEDROCK_CONTENT_FIELD + "' in Bedrock response.");
} else if (contentList.isEmpty()) {
errors.add("Empty content list in Bedrock response.");
} else {
Object first = contentList.get(0);
if (!(first instanceof Map<?, ?> firstMap)) {
errors.add("Unexpected content format in Bedrock response.");
} else {
Object text = firstMap.get(BEDROCK_TEXT_FIELD);
if (text == null) {
errors.add("Bedrock content response missing '" + BEDROCK_TEXT_FIELD + "' field.");
} else {
answers.add(text.toString());
}
}
}
}
} else if (provider == ModelProvider.COHERE) {
answerField = "text";
fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);