|
1 | 1 | package com.microsoft.openai.samples.rag.ask.approaches;
|
2 | 2 |
|
3 |
| -import com.azure.ai.openai.models.Completions; |
4 |
| -import com.azure.ai.openai.models.CompletionsOptions; |
5 |
| -import com.azure.search.documents.util.SearchPagedIterable; |
| 3 | +import com.azure.ai.openai.models.ChatCompletions; |
6 | 4 | import com.microsoft.openai.samples.rag.approaches.ContentSource;
|
7 | 5 | import com.microsoft.openai.samples.rag.approaches.RAGApproach;
|
8 | 6 | import com.microsoft.openai.samples.rag.approaches.RAGOptions;
|
9 | 7 | import com.microsoft.openai.samples.rag.approaches.RAGResponse;
|
10 |
| -import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy; |
| 8 | +import com.microsoft.openai.samples.rag.common.ChatGPTUtils; |
11 | 9 | import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
|
| 10 | +import com.microsoft.openai.samples.rag.retrieval.CognitiveSearchRetriever; |
| 11 | +import com.microsoft.openai.samples.rag.retrieval.Retriever; |
12 | 12 | import org.slf4j.Logger;
|
13 | 13 | import org.slf4j.LoggerFactory;
|
| 14 | +import org.springframework.context.ApplicationContext; |
| 15 | +import org.springframework.context.ApplicationContextAware; |
14 | 16 | import org.springframework.stereotype.Component;
|
15 | 17 |
|
16 |
| -import java.util.ArrayList; |
17 |
| -import java.util.Collections; |
18 | 18 | import java.util.List;
|
19 | 19 |
|
20 | 20 | /**
|
|
23 | 23 | * (answer) with that prompt.
|
24 | 24 | */
|
25 | 25 | @Component
|
26 |
| -public class RetrieveThenReadApproach implements RAGApproach<String, RAGResponse> { |
| 26 | +public class RetrieveThenReadApproach implements RAGApproach<String, RAGResponse>, ApplicationContextAware { |
27 | 27 |
|
28 | 28 | private static final Logger LOGGER = LoggerFactory.getLogger(RetrieveThenReadApproach.class);
|
29 |
| - private final CognitiveSearchProxy cognitiveSearchProxy; |
| 29 | + private ApplicationContext applicationContext; |
30 | 30 | private final OpenAIProxy openAIProxy;
|
| 31 | + private final Retriever factsRetriever; |
31 | 32 |
|
32 |
| - public RetrieveThenReadApproach(CognitiveSearchProxy cognitiveSearchProxy, OpenAIProxy openAIProxy) { |
33 |
| - this.cognitiveSearchProxy = cognitiveSearchProxy; |
| 33 | + public RetrieveThenReadApproach(Retriever factsRetriever, OpenAIProxy openAIProxy) { |
| 34 | + this.factsRetriever = factsRetriever; |
34 | 35 | this.openAIProxy = openAIProxy;
|
35 | 36 | }
|
36 | 37 |
|
37 | 38 | /**
|
38 |
| - * @param questionOrConversation |
| 39 | + * @param question |
39 | 40 | * @param options
|
40 | 41 | * @return
|
41 | 42 | */
|
42 | 43 | @Override
|
43 |
| - public RAGResponse run(String questionOrConversation, RAGOptions options) { |
| 44 | + public RAGResponse run(String question, RAGOptions options) { |
44 | 45 | //TODO exception handling
|
45 |
| - SearchPagedIterable searchResults = getCognitiveSearchResults(questionOrConversation, options); |
46 | 46 |
|
47 |
| - List<ContentSource> sources = buildSourcesFromSearchResults(options, searchResults); |
| 47 | + Retriever factsRetriever = getFactsRetriever(options); |
| 48 | + List<ContentSource> sources = factsRetriever.retrieveFromQuestion(question, options); |
48 | 49 | LOGGER.info("Total {} sources found in cognitive search for keyword search query[{}]", sources.size(),
|
49 |
| - questionOrConversation); |
| 50 | + question); |
50 | 51 |
|
51 |
| - var retrieveThenReadPrompt = new SemanticSearchAskPrompt(sources, questionOrConversation); |
| 52 | + var customPrompt = options.getPromptTemplate(); |
| 53 | + var customPromptEmpty = (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty()); |
52 | 54 |
|
53 |
| - var completionsOptions = buildCompletionsOptions(retrieveThenReadPrompt); |
| 55 | + //true will replace the default prompt. False will add custom prompt as suffix to the default prompt |
| 56 | + var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|"); |
| 57 | + if(!replacePrompt && !customPromptEmpty){ |
| 58 | + customPrompt = customPrompt.substring(1); |
| 59 | + } |
54 | 60 |
|
55 |
| - Completions completionsResults = openAIProxy.getCompletions(completionsOptions); |
| 61 | + var answerQuestionChatTemplate = new AnswerQuestionChatTemplate(customPrompt, replacePrompt); |
56 | 62 |
|
57 |
| - LOGGER.info("Completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]", |
58 |
| - completionsResults.getUsage().getPromptTokens(), |
59 |
| - completionsResults.getUsage().getCompletionTokens(), |
60 |
| - completionsResults.getUsage().getTotalTokens()); |
| 63 | + var groundedChatMessages = answerQuestionChatTemplate.getMessages(question,sources); |
| 64 | + var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages); |
| 65 | + |
| 66 | + // STEP 3: Generate a contextual and content specific answer using the retrieved facts |
| 67 | + ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions); |
| 68 | + |
| 69 | + LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]", |
| 70 | + chatCompletions.getUsage().getPromptTokens(), |
| 71 | + chatCompletions.getUsage().getCompletionTokens(), |
| 72 | + chatCompletions.getUsage().getTotalTokens()); |
61 | 73 |
|
62 | 74 | return new RAGResponse.Builder()
|
63 |
| - .prompt(retrieveThenReadPrompt.getFormattedPrompt()) |
64 |
| - .answer(completionsResults.getChoices().get(0).getText()) |
65 |
| - .sources(sources) |
66 |
| - .question(questionOrConversation) |
67 |
| - .build(); |
| 75 | + .question(question) |
| 76 | + .prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages)) |
| 77 | + .answer(chatCompletions.getChoices().get(0).getMessage().getContent()) |
| 78 | + .sources(sources) |
| 79 | + .build(); |
| 80 | + |
68 | 81 | }
|
69 | 82 |
|
70 |
| - private CompletionsOptions buildCompletionsOptions(SemanticSearchAskPrompt retrieveThenReadPrompt) { |
71 |
| - CompletionsOptions completionsOptions = new CompletionsOptions(new ArrayList<>(Collections.singletonList(retrieveThenReadPrompt.getFormattedPrompt()))); |
72 |
| - // Due to a potential bug when using JVM 17 and java openai SDK 1.0.0-beta.2, we need to provide default for all properties to avoid 404 bad Request on the server |
73 |
| - completionsOptions.setStop(List.of("\n")); |
74 |
| - return fillCommonCompletionsOptions(completionsOptions); |
| 83 | + /** |
| 84 | + * |
| 85 | + * @param options rag options containing search types (Cognitive Semantic Search, Cognitive Vector Search, Cognitive Hybrid Search, Semantic Kernel Memory) |
| 86 | + * @return retriever implementation |
| 87 | + */ |
| 88 | + private CognitiveSearchRetriever getFactsRetriever(RAGOptions options) { |
| 89 | + //default to Cognitive Semantic Search for MVP. |
| 90 | + return this.applicationContext.getBean(CognitiveSearchRetriever.class); |
| 91 | + |
75 | 92 | }
|
76 | 93 |
|
77 |
| - @Override |
78 |
| - public CognitiveSearchProxy getCognitiveSearchProxy() { |
79 |
| - return this.cognitiveSearchProxy; |
| 94 | + public void setApplicationContext(ApplicationContext applicationContext) { |
| 95 | + this.applicationContext = applicationContext; |
80 | 96 | }
|
81 | 97 |
|
82 | 98 | }
|
0 commit comments