Skip to content

Commit 2d271ea

Browse files
committed
merge PR from dsibilio and upgrade all approaches to use chat completion api
1 parent a1b9ec9 commit 2d271ea

29 files changed

+527
-853
lines changed

.github/workflows/app-ci.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ on:
99
branches:
1010
- app-ci-github-actions
1111
- main
12+
pull_request:
13+
branches: [ main ]
1214
workflow_dispatch:
1315

1416
jobs:
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
name: Validate AZD template
2+
on:
3+
push:
4+
branches: [ main ]
5+
paths:
6+
- "infra/**"
7+
pull_request:
8+
branches: [ main ]
9+
paths:
10+
- "infra/**"
11+
12+
jobs:
13+
build:
14+
15+
runs-on: ubuntu-latest
16+
steps:
17+
- name: Checkout
18+
uses: actions/checkout@v3
19+
20+
- name: Build Bicep for linting
21+
uses: azure/CLI@v1
22+
with:
23+
inlineScript: az config set bicep.use_binary_from_path=false && az bicep build -f infra/main.bicep --stdout
24+
25+
- name: Run Microsoft Security DevOps Analysis
26+
uses: microsoft/security-devops-action@preview
27+
id: msdo
28+
continue-on-error: true
29+
with:
30+
tools: templateanalyzer
31+
32+
- name: Upload alerts to Security tab
33+
uses: github/codeql-action/upload-sarif@v2
34+
if: github.repository == 'Azure-Samples/azure-search-openai-demo'
35+
with:
36+
sarif_file: ${{ steps.msdo.outputs.sarifFile }}
Lines changed: 0 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,11 @@
11
package com.microsoft.openai.samples.rag.approaches;
22

3-
import com.azure.ai.openai.models.CompletionsOptions;
4-
import com.azure.core.util.Context;
5-
import com.azure.search.documents.SearchDocument;
6-
import com.azure.search.documents.models.*;
7-
import com.azure.search.documents.util.SearchPagedIterable;
8-
import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy;
9-
10-
import java.util.ArrayList;
11-
import java.util.HashMap;
12-
import java.util.List;
13-
import java.util.Optional;
14-
153
public interface RAGApproach<I, O> {
164

175
O run(I questionOrConversation, RAGOptions options);
186

19-
CognitiveSearchProxy getCognitiveSearchProxy();
20-
21-
default CompletionsOptions fillCommonCompletionsOptions(CompletionsOptions completionsOptions) {
22-
completionsOptions.setMaxTokens(1024);
23-
completionsOptions.setTemperature(0.3);
24-
completionsOptions.setLogitBias(new HashMap<>());
25-
completionsOptions.setEcho(false);
26-
completionsOptions.setN(1);
27-
completionsOptions.setStream(false);
28-
completionsOptions.setUser("search-openai-demo-java");
29-
completionsOptions.setPresencePenalty(0.0);
30-
completionsOptions.setFrequencyPenalty(0.0);
31-
completionsOptions.setBestOf(1);
32-
return completionsOptions;
33-
}
34-
35-
default SearchOptions buildSearchOptions(RAGOptions options) {
36-
var searchOptions = new SearchOptions();
37-
38-
Optional.ofNullable(options.getTop()).ifPresentOrElse(
39-
searchOptions::setTop,
40-
() -> searchOptions.setTop(3));
41-
Optional.ofNullable(options.getExcludeCategory())
42-
.ifPresentOrElse(
43-
value -> searchOptions.setFilter("category ne '%s'".formatted(value.replace("'", "''"))),
44-
() -> searchOptions.setFilter(null));
45-
46-
Optional.ofNullable(options.isSemanticRanker()).ifPresent(isSemanticRanker -> {
47-
if (isSemanticRanker) {
48-
searchOptions.setQueryType(QueryType.SEMANTIC);
49-
searchOptions.setQueryLanguage(QueryLanguage.EN_US);
50-
searchOptions.setSpeller(QuerySpellerType.LEXICON);
51-
searchOptions.setSemanticConfigurationName("default");
52-
searchOptions.setQueryCaption(QueryCaptionType.EXTRACTIVE);
53-
searchOptions.setQueryCaptionHighlightEnabled(false);
54-
}
55-
});
56-
57-
return searchOptions;
58-
}
59-
60-
default SearchPagedIterable getCognitiveSearchResults(String question, RAGOptions options) {
61-
return getCognitiveSearchProxy().search(question, buildSearchOptions(options), Context.NONE);
62-
}
63-
64-
default List<ContentSource> buildSourcesFromSearchResults(RAGOptions options, SearchPagedIterable searchResults) {
65-
List<ContentSource> sources = new ArrayList<>();
66-
67-
searchResults.iterator().forEachRemaining(result ->
68-
{
69-
var searchDocument = result.getDocument(SearchDocument.class);
70-
71-
/*
72-
If captions is enabled the content source is taken from the captions generated by the semantic ranker.
73-
Captions are appended sequentially and separated by a dot.
74-
*/
75-
if(options.isSemanticCaptions()) {
76-
StringBuilder sourcesContentBuffer = new StringBuilder();
777

78-
result.getCaptions().forEach(caption -> sourcesContentBuffer.append(caption.getText()).append("."));
798

80-
sources.add(new ContentSource((String)searchDocument.get("sourcepage"), sourcesContentBuffer.toString()));
81-
} else {
82-
//If captions is disabled the content source is taken from the cognitive search index field "content"
83-
sources.add(new ContentSource((String) searchDocument.get("sourcepage"), (String) searchDocument.get("content")));
84-
}
85-
});
869

87-
return sources;
88-
}
8910

9011
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package com.microsoft.openai.samples.rag.ask.approaches;
2+
3+
import com.azure.ai.openai.models.ChatMessage;
4+
import com.azure.ai.openai.models.ChatRole;
5+
import com.microsoft.openai.samples.rag.approaches.ContentSource;
6+
import java.util.ArrayList;
7+
import java.util.List;
8+
9+
public class AnswerQuestionChatTemplate {
10+
11+
private final List<ChatMessage> conversationHistory = new ArrayList<>();
12+
13+
private String customPrompt = "";
14+
private String systemMessage;
15+
private Boolean replacePrompt = false;
16+
17+
private static final String SYSTEM_CHAT_MESSAGE_TEMPLATE = """
18+
You are an intelligent assistant helping Contoso Inc employees with their healthcare plan questions and employee handbook questions.
19+
Use 'you' to refer to the individual asking the questions even if they ask with 'I'.
20+
Answer the following question using only the data provided in the sources below.
21+
For tabular information return it as an html table. Do not return markdown format.
22+
Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response.
23+
If you cannot answer using the sources below, say you don't know. Use below example to answer
24+
%s
25+
""" ;
26+
27+
private static final String FEW_SHOT_USER_MESSAGE = """
28+
What is the deductible for the employee plan for a visit to Overlake in Bellevue?'
29+
Sources:
30+
info1.txt: deductibles depend on whether you are in-network or out-of-network. In-network deductibles are $500 for employee and $1000 for family. Out-of-network deductibles are $1000 for employee and $2000 for family.
31+
info2.pdf: Overlake is in-network for the employee plan.
32+
info3.pdf: Overlake is the name of the area that includes a park and ride near Bellevue.
33+
info4.pdf: In-network institutions include Overlake, Swedish and others in the region
34+
""";
35+
private static final String FEW_SHOT_ASSISTANT_MESSAGE = """
36+
In-network deductibles are $500 for employee and $1000 for family [info1.txt] and Overlake is in-network for the employee plan [info2.pdf][info4.pdf].
37+
""";
38+
39+
/**
40+
*
41+
* @param conversation conversation history
42+
* @param sources domain specific sources to be used in the prompt
43+
* @param customPrompt custom prompt to be injected in the existing promptTemplate or used to replace it
44+
* @param replacePrompt if true, the customPrompt will replace the default promptTemplate, otherwise it will be appended
45+
* to the default promptTemplate in the predefined section
46+
*/
47+
48+
private static final String GROUNDED_USER_QUESTION_TEMPLATE = """
49+
%s
50+
Sources:
51+
%s
52+
""";
53+
public AnswerQuestionChatTemplate( String customPrompt, Boolean replacePrompt) {
54+
55+
if(replacePrompt && (customPrompt == null || customPrompt.isEmpty()))
56+
throw new IllegalStateException("customPrompt cannot be null or empty when replacePrompt is true");
57+
58+
this.replacePrompt = replacePrompt;
59+
this.customPrompt = customPrompt == null ? "" : customPrompt;
60+
61+
62+
if(this.replacePrompt){
63+
this.systemMessage = customPrompt;
64+
} else {
65+
this.systemMessage = SYSTEM_CHAT_MESSAGE_TEMPLATE.formatted(this.customPrompt);
66+
}
67+
68+
//Add system message
69+
ChatMessage chatSystemMessage = new ChatMessage(ChatRole.SYSTEM);
70+
chatSystemMessage.setContent(systemMessage);
71+
72+
this.conversationHistory.add(chatSystemMessage);
73+
74+
//Add few-shot learning examples to the chat
75+
ChatMessage fewShotUserMessage = new ChatMessage(ChatRole.USER);
76+
fewShotUserMessage.setContent(FEW_SHOT_USER_MESSAGE);
77+
this.conversationHistory.add(fewShotUserMessage);
78+
79+
ChatMessage fewShotAssistantMessage = new ChatMessage(ChatRole.ASSISTANT);
80+
fewShotAssistantMessage.setContent(FEW_SHOT_ASSISTANT_MESSAGE);
81+
this.conversationHistory.add(fewShotAssistantMessage);
82+
}
83+
84+
85+
public List<ChatMessage> getMessages(String question,List<ContentSource> sources ) {
86+
if (sources == null || sources.isEmpty())
87+
throw new IllegalStateException("sources cannot be null or empty");
88+
if (question == null || question.isEmpty())
89+
throw new IllegalStateException("question cannot be null");
90+
91+
StringBuilder sourcesStringBuilder = new StringBuilder();
92+
// Build sources section
93+
sources.iterator().forEachRemaining(source -> sourcesStringBuilder.append(source.getSourceName()).append(": ").append(source.getSourceContent()).append("\n"));
94+
95+
//Add user question with retrieved facts
96+
String groundedUserQuestion = GROUNDED_USER_QUESTION_TEMPLATE.formatted(question,sourcesStringBuilder.toString());
97+
ChatMessage groundedUserMessage = new ChatMessage(ChatRole.USER);
98+
groundedUserMessage.setContent(groundedUserQuestion);
99+
this.conversationHistory.add(groundedUserMessage);
100+
101+
return this.conversationHistory;
102+
}
103+
104+
105+
}
Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
package com.microsoft.openai.samples.rag.ask.approaches;
22

3-
import com.azure.ai.openai.models.Completions;
4-
import com.azure.ai.openai.models.CompletionsOptions;
5-
import com.azure.search.documents.util.SearchPagedIterable;
3+
import com.azure.ai.openai.models.ChatCompletions;
64
import com.microsoft.openai.samples.rag.approaches.ContentSource;
75
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
86
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
97
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
10-
import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy;
8+
import com.microsoft.openai.samples.rag.common.ChatGPTUtils;
119
import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
10+
import com.microsoft.openai.samples.rag.retrieval.CognitiveSearchRetriever;
11+
import com.microsoft.openai.samples.rag.retrieval.Retriever;
1212
import org.slf4j.Logger;
1313
import org.slf4j.LoggerFactory;
14+
import org.springframework.context.ApplicationContext;
15+
import org.springframework.context.ApplicationContextAware;
1416
import org.springframework.stereotype.Component;
1517

16-
import java.util.ArrayList;
17-
import java.util.Collections;
1818
import java.util.List;
1919

2020
/**
@@ -23,60 +23,76 @@
2323
* (answer) with that prompt.
2424
*/
2525
@Component
26-
public class RetrieveThenReadApproach implements RAGApproach<String, RAGResponse> {
26+
public class RetrieveThenReadApproach implements RAGApproach<String, RAGResponse>, ApplicationContextAware {
2727

2828
private static final Logger LOGGER = LoggerFactory.getLogger(RetrieveThenReadApproach.class);
29-
private final CognitiveSearchProxy cognitiveSearchProxy;
29+
private ApplicationContext applicationContext;
3030
private final OpenAIProxy openAIProxy;
31+
private final Retriever factsRetriever;
3132

32-
public RetrieveThenReadApproach(CognitiveSearchProxy cognitiveSearchProxy, OpenAIProxy openAIProxy) {
33-
this.cognitiveSearchProxy = cognitiveSearchProxy;
33+
public RetrieveThenReadApproach(Retriever factsRetriever, OpenAIProxy openAIProxy) {
34+
this.factsRetriever = factsRetriever;
3435
this.openAIProxy = openAIProxy;
3536
}
3637

3738
/**
38-
* @param questionOrConversation
39+
* @param question
3940
* @param options
4041
* @return
4142
*/
4243
@Override
43-
public RAGResponse run(String questionOrConversation, RAGOptions options) {
44+
public RAGResponse run(String question, RAGOptions options) {
4445
//TODO exception handling
45-
SearchPagedIterable searchResults = getCognitiveSearchResults(questionOrConversation, options);
4646

47-
List<ContentSource> sources = buildSourcesFromSearchResults(options, searchResults);
47+
Retriever factsRetriever = getFactsRetriever(options);
48+
List<ContentSource> sources = factsRetriever.retrieveFromQuestion(question, options);
4849
LOGGER.info("Total {} sources found in cognitive search for keyword search query[{}]", sources.size(),
49-
questionOrConversation);
50+
question);
5051

51-
var retrieveThenReadPrompt = new SemanticSearchAskPrompt(sources, questionOrConversation);
52+
var customPrompt = options.getPromptTemplate();
53+
var customPromptEmpty = (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty());
5254

53-
var completionsOptions = buildCompletionsOptions(retrieveThenReadPrompt);
55+
//true: the custom prompt replaces the default prompt. false: a custom prompt starting with "|" is appended (without the leading "|") as a suffix to the default prompt
56+
var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|");
57+
if(!replacePrompt && !customPromptEmpty){
58+
customPrompt = customPrompt.substring(1);
59+
}
5460

55-
Completions completionsResults = openAIProxy.getCompletions(completionsOptions);
61+
var answerQuestionChatTemplate = new AnswerQuestionChatTemplate(customPrompt, replacePrompt);
5662

57-
LOGGER.info("Completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
58-
completionsResults.getUsage().getPromptTokens(),
59-
completionsResults.getUsage().getCompletionTokens(),
60-
completionsResults.getUsage().getTotalTokens());
63+
var groundedChatMessages = answerQuestionChatTemplate.getMessages(question,sources);
64+
var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages);
65+
66+
// STEP 3: Generate a contextual and content-specific answer using the retrieved facts
67+
ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions);
68+
69+
LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
70+
chatCompletions.getUsage().getPromptTokens(),
71+
chatCompletions.getUsage().getCompletionTokens(),
72+
chatCompletions.getUsage().getTotalTokens());
6173

6274
return new RAGResponse.Builder()
63-
.prompt(retrieveThenReadPrompt.getFormattedPrompt())
64-
.answer(completionsResults.getChoices().get(0).getText())
65-
.sources(sources)
66-
.question(questionOrConversation)
67-
.build();
75+
.question(question)
76+
.prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
77+
.answer(chatCompletions.getChoices().get(0).getMessage().getContent())
78+
.sources(sources)
79+
.build();
80+
6881
}
6982

70-
private CompletionsOptions buildCompletionsOptions(SemanticSearchAskPrompt retrieveThenReadPrompt) {
71-
CompletionsOptions completionsOptions = new CompletionsOptions(new ArrayList<>(Collections.singletonList(retrieveThenReadPrompt.getFormattedPrompt())));
72-
// Due to a potential bug when using JVM 17 and java openai SDK 1.0.0-beta.2, we need to provide default for all properties to avoid 404 bad Request on the server
73-
completionsOptions.setStop(List.of("\n"));
74-
return fillCommonCompletionsOptions(completionsOptions);
83+
/**
84+
*
85+
* @param options rag options containing search types (Cognitive Semantic Search, Cognitive Vector Search, Cognitive Hybrid Search, Semantic Kernel Memory)
86+
* @return retriever implementation
87+
*/
88+
private CognitiveSearchRetriever getFactsRetriever(RAGOptions options) {
89+
//default to Cognitive Semantic Search for MVP.
90+
return this.applicationContext.getBean(CognitiveSearchRetriever.class);
91+
7592
}
7693

77-
@Override
78-
public CognitiveSearchProxy getCognitiveSearchProxy() {
79-
return this.cognitiveSearchProxy;
94+
public void setApplicationContext(ApplicationContext applicationContext) {
95+
this.applicationContext = applicationContext;
8096
}
8197

8298
}

0 commit comments

Comments
 (0)