Skip to content

Commit e18f249

Browse files
Fix prompt passing for Bedrock by passing a single string prompt for … (#1490)
* Fix prompt passing for Bedrock by passing a single string prompt for Bedrock models. (#1476) Signed-off-by: Austin Lee <[email protected]> * Add unit tests, apply Spotless. Signed-off-by: Austin Lee <[email protected]> * Check if systemPrompt is null. Signed-off-by: Austin Lee <[email protected]> * Address review comments. Signed-off-by: Austin Lee <[email protected]> --------- Signed-off-by: Austin Lee <[email protected]>
1 parent f4446cb commit e18f249

File tree

10 files changed

+281
-30
lines changed

10 files changed

+281
-30
lines changed

search-processors/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ GET /<index>/_search\?search_pipeline\=<search pipeline name>
4949
}
5050
```
5151

52+
To use this with Bedrock models, use "bedrock/" as a prefix for the "llm_model" parameter, e.g. "bedrock/anthropic".
53+
54+
The latest RAG processor has been tested with OpenAI's GPT 3.5 and 4 models and Bedrock's Anthropic Claude (v2) model only.
55+
5256
## Retrieval Augmented Generation response
5357
```
5458
{

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,5 @@ public class ChatCompletionInput {
4242
private int timeoutInSeconds;
4343
private String systemPrompt;
4444
private String userInstructions;
45+
private Llm.ModelProvider modelProvider;
4546
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java

Lines changed: 75 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public DefaultLlmImpl(String openSearchModelId, Client client) {
6363
}
6464

6565
@VisibleForTesting
66-
void setMlClient(MachineLearningInternalClient mlClient) {
66+
protected void setMlClient(MachineLearningInternalClient mlClient) {
6767
this.mlClient = mlClient;
6868
}
6969

@@ -76,19 +76,7 @@ void setMlClient(MachineLearningInternalClient mlClient) {
7676
@Override
7777
public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {
7878

79-
Map<String, String> inputParameters = new HashMap<>();
80-
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
81-
String messages = PromptUtil
82-
.getChatCompletionPrompt(
83-
chatCompletionInput.getSystemPrompt(),
84-
chatCompletionInput.getUserInstructions(),
85-
chatCompletionInput.getQuestion(),
86-
chatCompletionInput.getChatHistory(),
87-
chatCompletionInput.getContexts()
88-
);
89-
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
90-
log.info("Messages to LLM: {}", messages);
91-
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build();
79+
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
9280
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
9381
ActionFuture<MLOutput> future = mlClient.predict(this.openSearchModelId, mlInput);
9482
ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000);
@@ -99,19 +87,83 @@ public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionI
9987

10088
// TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases.
10189

102-
List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
90+
return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap);
91+
}
92+
93+
protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
94+
Map<String, String> inputParameters = new HashMap<>();
95+
96+
if (chatCompletionInput.getModelProvider() == ModelProvider.OPENAI) {
97+
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
98+
String messages = PromptUtil
99+
.getChatCompletionPrompt(
100+
chatCompletionInput.getSystemPrompt(),
101+
chatCompletionInput.getUserInstructions(),
102+
chatCompletionInput.getQuestion(),
103+
chatCompletionInput.getChatHistory(),
104+
chatCompletionInput.getContexts()
105+
);
106+
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
107+
log.info("Messages to LLM: {}", messages);
108+
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK) {
109+
inputParameters
110+
.put(
111+
"inputs",
112+
PromptUtil
113+
.buildSingleStringPrompt(
114+
chatCompletionInput.getSystemPrompt(),
115+
chatCompletionInput.getUserInstructions(),
116+
chatCompletionInput.getQuestion(),
117+
chatCompletionInput.getChatHistory(),
118+
chatCompletionInput.getContexts()
119+
)
120+
);
121+
} else {
122+
throw new IllegalArgumentException("Unknown/unsupported model provider: " + chatCompletionInput.getModelProvider());
123+
}
124+
125+
log.info("LLM input parameters: {}", inputParameters.toString());
126+
return inputParameters;
127+
}
128+
129+
protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map<String, ?> dataAsMap) {
130+
103131
List<Object> answers = null;
104132
List<String> errors = null;
105-
if (choices == null) {
106-
Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
107-
errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE));
133+
134+
if (provider == ModelProvider.OPENAI) {
135+
List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
136+
if (choices == null) {
137+
Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
138+
errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE));
139+
} else {
140+
Map firstChoiceMap = (Map) choices.get(0);
141+
log.info("Choices: {}", firstChoiceMap.toString());
142+
Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE);
143+
log
144+
.info(
145+
"role: {}, content: {}",
146+
message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE),
147+
message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)
148+
);
149+
answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
150+
}
151+
} else if (provider == ModelProvider.BEDROCK) {
152+
String response = (String) dataAsMap.get("completion");
153+
if (response != null) {
154+
answers = List.of(response);
155+
} else {
156+
Map error = (Map) dataAsMap.get("error");
157+
if (error != null) {
158+
errors = List.of((String) error.get("message"));
159+
} else {
160+
errors = List.of("Unknown error or response.");
161+
}
162+
}
108163
} else {
109-
Map firstChoiceMap = (Map) choices.get(0);
110-
log.info("Choices: {}", firstChoiceMap.toString());
111-
Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE);
112-
log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
113-
answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
164+
throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider);
114165
}
166+
115167
return new ChatCompletionOutput(answers, errors);
116168
}
117169
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,11 @@
2222
*/
2323
public interface Llm {
2424

25+
// TODO Ensure the current implementation works with all models supported by Bedrock.
26+
enum ModelProvider {
27+
OPENAI,
28+
BEDROCK
29+
}
30+
2531
ChatCompletionOutput doChatCompletion(ChatCompletionInput input);
2632
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
*/
2828
public class LlmIOUtil {
2929

30+
private static final String BEDROCK_PROVIDER_PREFIX = "bedrock/";
31+
3032
public static ChatCompletionInput createChatCompletionInput(
3133
String llmModel,
3234
String question,
@@ -57,7 +59,19 @@ public static ChatCompletionInput createChatCompletionInput(
5759
List<String> contexts,
5860
int timeoutInSeconds
5961
) {
60-
61-
return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions);
62+
Llm.ModelProvider provider = Llm.ModelProvider.OPENAI;
63+
if (llmModel != null && llmModel.startsWith(BEDROCK_PROVIDER_PREFIX)) {
64+
provider = Llm.ModelProvider.BEDROCK;
65+
}
66+
return new ChatCompletionInput(
67+
llmModel,
68+
question,
69+
chatHistory,
70+
contexts,
71+
timeoutInSeconds,
72+
systemPrompt,
73+
userInstructions,
74+
provider
75+
);
6276
}
6377
}

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.ArrayList;
2121
import java.util.Collections;
2222
import java.util.List;
23+
import java.util.Locale;
2324

2425
import org.apache.commons.text.StringEscapeUtils;
2526
import org.opensearch.core.common.Strings;
@@ -54,6 +55,8 @@ public class PromptUtil {
5455

5556
private static final String roleUser = "user";
5657

58+
private static final String NEWLINE = "\\n";
59+
5760
public static String getQuestionRephrasingPrompt(String originalQuestion, List<Interaction> chatHistory) {
5861
return null;
5962
}
@@ -62,6 +65,8 @@ public static String getChatCompletionPrompt(String question, List<Interaction>
6265
return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts);
6366
}
6467

68+
// TODO Currently, this is OpenAI specific. Change this to indicate as such or address it as part of
69+
// future prompt template management work.
6570
public static String getChatCompletionPrompt(
6671
String systemPrompt,
6772
String userInstructions,
@@ -87,6 +92,48 @@ enum ChatRole {
8792
}
8893
}
8994

95+
public static String buildSingleStringPrompt(
96+
String systemPrompt,
97+
String userInstructions,
98+
String question,
99+
List<Interaction> chatHistory,
100+
List<String> contexts
101+
) {
102+
if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) {
103+
systemPrompt = DEFAULT_SYSTEM_PROMPT;
104+
}
105+
106+
StringBuilder bldr = new StringBuilder();
107+
108+
if (!Strings.isNullOrEmpty(systemPrompt)) {
109+
bldr.append(systemPrompt);
110+
bldr.append(NEWLINE);
111+
}
112+
if (!Strings.isNullOrEmpty(userInstructions)) {
113+
bldr.append(userInstructions);
114+
bldr.append(NEWLINE);
115+
}
116+
117+
for (int i = 0; i < contexts.size(); i++) {
118+
bldr.append("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i));
119+
bldr.append(NEWLINE);
120+
}
121+
if (!chatHistory.isEmpty()) {
122+
// The oldest interaction first
123+
List<Message> messages = Messages.fromInteractions(chatHistory).getMessages();
124+
Collections.reverse(messages);
125+
messages.forEach(m -> {
126+
bldr.append(m.toString());
127+
bldr.append(NEWLINE);
128+
});
129+
130+
}
131+
bldr.append("QUESTION: " + question);
132+
bldr.append(NEWLINE);
133+
134+
return bldr.toString();
135+
}
136+
90137
@VisibleForTesting
91138
static String buildMessageParameter(
92139
String systemPrompt,
@@ -110,7 +157,6 @@ static String buildMessageParameter(
110157
}
111158
if (!chatHistory.isEmpty()) {
112159
// The oldest interaction first
113-
// Collections.reverse(chatHistory);
114160
List<Message> messages = Messages.fromInteractions(chatHistory).getMessages();
115161
Collections.reverse(messages);
116162
messages.forEach(m -> messageArray.add(m.toJson()));
@@ -163,6 +209,8 @@ public static Messages fromInteractions(final List<Interaction> interactions) {
163209
}
164210
}
165211

212+
// TODO This is OpenAI specific. Either change this to OpenAiMessage or have it handle
213+
// vendor specific messages.
166214
static class Message {
167215

168216
private final static String MESSAGE_FIELD_ROLE = "role";
@@ -186,6 +234,7 @@ public Message(ChatRole chatRole, String content) {
186234
}
187235

188236
public void setChatRole(ChatRole chatRole) {
237+
this.chatRole = chatRole;
189238
json.remove(MESSAGE_FIELD_ROLE);
190239
json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName()));
191240
}
@@ -199,5 +248,10 @@ public void setContent(String content) {
199248
public JsonObject toJson() {
200249
return json;
201250
}
251+
252+
@Override
253+
public String toString() {
254+
return String.format(Locale.ROOT, "%s: %s", chatRole.getName(), content);
255+
}
202256
}
203257
}

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ public void testCtor() {
4141
Collections.emptyList(),
4242
0,
4343
systemPrompt,
44-
userInstructions
44+
userInstructions,
45+
Llm.ModelProvider.OPENAI
4546
);
4647

4748
assertNotNull(input);
@@ -70,7 +71,16 @@ public void testGettersSetters() {
7071
)
7172
);
7273
List<String> contexts = List.of("result1", "result2");
73-
ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions);
74+
ChatCompletionInput input = new ChatCompletionInput(
75+
model,
76+
question,
77+
history,
78+
contexts,
79+
0,
80+
systemPrompt,
81+
userInstructions,
82+
Llm.ModelProvider.OPENAI
83+
);
7484
assertEquals(model, input.getModel());
7585
assertEquals(question, input.getQuestion());
7686
assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId());

0 commit comments

Comments
 (0)