Skip to content

Commit 4da1711

Browse files
committed
Simplify logic
1 parent d9a47eb commit 4da1711

File tree

1 file changed

+49
-58
lines changed

1 file changed

+49
-58
lines changed
Lines changed: 49 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package com.microsoft.openai.samples.rag.ask.approaches;
22

33
import com.azure.ai.openai.models.ChatCompletions;
4-
import com.azure.ai.openai.models.ChatCompletionsOptions;
5-
import com.azure.ai.openai.models.ChatMessage;
64
import com.microsoft.openai.samples.rag.approaches.ContentSource;
75
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
86
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
@@ -42,65 +40,45 @@ public PlainJavaAskApproach(FactsRetrieverProvider factsRetrieverProvider, OpenA
4240
*/
4341
@Override
4442
public RAGResponse run(String question, RAGOptions options) {
45-
return formChatCompletionArguments(
46-
question,
47-
options,
48-
(chatCompletionsOptions, groundedChatMessages, sources) -> {
49-
// STEP 3: Generate a contextual and content specific answer using the retrieved facts
50-
ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions);
43+
//Get instance of retriever based on the retrieval mode: hybrid, text, vectors.
44+
Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options);
45+
List<ContentSource> sources = factsRetriever.retrieveFromQuestion(question, options);
46+
LOGGER.info("Total {} sources found in cognitive search for keyword search query[{}]", sources.size(),
47+
question);
5148

52-
LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
53-
chatCompletions.getUsage().getPromptTokens(),
54-
chatCompletions.getUsage().getCompletionTokens(),
55-
chatCompletions.getUsage().getTotalTokens());
56-
57-
return new RAGResponse.Builder()
58-
.question(question)
59-
.prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
60-
.answer(chatCompletions.getChoices().get(0).getMessage().getContent())
61-
.sources(sources)
62-
.build();
63-
});
49+
var customPrompt = options.getPromptTemplate();
50+
var customPromptEmpty = (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty());
51+
52+
//true will replace the default prompt. False will add custom prompt as suffix to the default prompt
53+
var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|");
54+
if (!replacePrompt && !customPromptEmpty) {
55+
customPrompt = customPrompt.substring(1);
56+
}
57+
58+
var answerQuestionChatTemplate = new AnswerQuestionChatTemplate(customPrompt, replacePrompt);
59+
60+
var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources);
61+
var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages);
62+
63+
// STEP 3: Generate a contextual and content specific answer using the retrieved facts
64+
ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions);
65+
66+
LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
67+
chatCompletions.getUsage().getPromptTokens(),
68+
chatCompletions.getUsage().getCompletionTokens(),
69+
chatCompletions.getUsage().getTotalTokens());
70+
71+
return new RAGResponse.Builder()
72+
.question(question)
73+
.prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
74+
.answer(chatCompletions.getChoices().get(0).getMessage().getContent())
75+
.sources(sources)
76+
.build();
6477
}
6578

6679
@Override
6780
public Flux<RAGResponse> runStreaming(String question, RAGOptions options) {
68-
return formChatCompletionArguments(
69-
question,
70-
options,
71-
(chatCompletionsOptions, groundedChatMessages, sources) -> {
72-
Flux<ChatCompletions> completions = Flux.fromIterable(openAIProxy.getChatCompletionsStream(chatCompletionsOptions));
73-
return completions
74-
.flatMap(completion -> {
75-
LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
76-
completion.getUsage().getPromptTokens(),
77-
completion.getUsage().getCompletionTokens(),
78-
completion.getUsage().getTotalTokens());
79-
80-
return Flux.fromIterable(completion.getChoices())
81-
.map(choice -> new RAGResponse.Builder()
82-
.question(question)
83-
.prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
84-
.answer(choice.getMessage().getContent())
85-
.sources(sources)
86-
.build());
87-
});
88-
});
89-
}
90-
91-
private interface CompletionFunction<T> {
92-
T apply(
93-
ChatCompletionsOptions chatCompletionsOptions,
94-
List<ChatMessage> groundedChatMessages,
95-
List<ContentSource> sources
96-
);
97-
}
9881

99-
private <T> T formChatCompletionArguments(
100-
String question,
101-
RAGOptions options,
102-
CompletionFunction<T> completionFunction
103-
) {
10482
//Get instance of retriever based on the retrieval mode: hybrid, text, vectors.
10583
Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options);
10684
List<ContentSource> sources = factsRetriever.retrieveFromQuestion(question, options);
@@ -121,8 +99,21 @@ private <T> T formChatCompletionArguments(
12199
var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources);
122100
var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages);
123101

124-
return completionFunction.apply(chatCompletionsOptions, groundedChatMessages, sources);
102+
Flux<ChatCompletions> completions = Flux.fromIterable(openAIProxy.getChatCompletionsStream(chatCompletionsOptions));
103+
return completions
104+
.flatMap(completion -> {
105+
LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
106+
completion.getUsage().getPromptTokens(),
107+
completion.getUsage().getCompletionTokens(),
108+
completion.getUsage().getTotalTokens());
109+
110+
return Flux.fromIterable(completion.getChoices())
111+
.map(choice -> new RAGResponse.Builder()
112+
.question(question)
113+
.prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
114+
.answer(choice.getMessage().getContent())
115+
.sources(sources)
116+
.build());
117+
});
125118
}
126-
127-
128119
}

0 commit comments

Comments
 (0)