Commit 57272e5

Author: Milder Hernandez Cagua (committed)
Add SK chat support
1 parent f51e6f5 commit 57272e5

File tree: 1 file changed, +46 -9 lines

app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/approaches/semantickernel/JavaSemanticKernelWithMemoryChatApproach.java

Lines changed: 46 additions & 9 deletions
@@ -10,13 +10,16 @@
 import com.microsoft.openai.samples.rag.approaches.RAGResponse;
 import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.memory.CustomAzureCognitiveSearchMemoryStore;
 import com.microsoft.openai.samples.rag.common.ChatGPTConversation;
+import com.microsoft.openai.samples.rag.common.ChatGPTMessage;
 import com.microsoft.openai.samples.rag.common.ChatGPTUtils;
 import com.microsoft.semantickernel.Kernel;
 import com.microsoft.semantickernel.SKBuilders;
 import com.microsoft.semantickernel.ai.embeddings.Embedding;
+import com.microsoft.semantickernel.chatcompletion.ChatCompletion;
+import com.microsoft.semantickernel.connectors.ai.openai.chatcompletion.OpenAIChatCompletion;
+import com.microsoft.semantickernel.connectors.ai.openai.chatcompletion.OpenAIChatHistory;
 import com.microsoft.semantickernel.memory.MemoryQueryResult;
 import com.microsoft.semantickernel.memory.MemoryRecord;
-import com.microsoft.semantickernel.orchestration.SKContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Value;
@@ -53,6 +56,24 @@ public class JavaSemanticKernelWithMemoryChatApproach implements RAGApproach<Cha
     @Value("${openai.embedding.deployment}")
     private String embeddingDeploymentModelId;

+    private static final String FOLLOW_UP_QUESTIONS_TEMPLATE = """
+            After answering question, also generate three very brief follow-up questions that the user would likely ask next.
+            Use double angle brackets to reference the questions, e.g. <<Are there exclusions for prescriptions?>>.
+            Try not to repeat questions that have already been asked.
+            Only generate questions and do not generate any text before or after the questions, such as 'Next Questions'
+            """;
+    private static final String SYSTEM_CHAT_MESSAGE_TEMPLATE = """
+            Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers.
+            Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.
+            For tabular information return it as an html table. Do not return markdown format.
+            Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, e.g. [info1.txt]. Don't combine sources, list each source separately, e.g. [info1.txt][info2.pdf].
+            %s
+
+            %s
+            Sources:
+            %s
+            """ ;
+
     public JavaSemanticKernelWithMemoryChatApproach(TokenCredential tokenCredential, OpenAIAsyncClient openAIAsyncClient, SearchAsyncClient searchAsyncClient) {
         this.tokenCredential = tokenCredential;
         this.openAIAsyncClient = openAIAsyncClient;
@@ -61,7 +82,6 @@ public JavaSemanticKernelWithMemoryChatApproach(TokenCredential tokenCredential,

     @Override
     public RAGResponse run(ChatGPTConversation questionOrConversation, RAGOptions options) {
-
         String question = ChatGPTUtils.getLastUserQuestion(questionOrConversation.getMessages());

         //Build semantic kernel with Azure Cognitive Search as memory store. AnswerQuestion skill is imported from resources.
@@ -84,29 +104,46 @@ public RAGResponse run(ChatGPTConversation questionOrConversation, RAGOptions op
         String sources = buildSourcesText(memoryResult);
         List<ContentSource> sourcesList = buildSources(memoryResult);

-        SKContext skcontext = SKBuilders.context().build()
-                .setVariable("sources", sources)
-                .setVariable("input", question);
+        // Use ChatCompletion Service to generate a reply
+        OpenAIChatCompletion chat = (OpenAIChatCompletion) semanticKernel.getService(null, ChatCompletion.class);
+        OpenAIChatHistory history = buildChatHistory(questionOrConversation, options, chat, sources);

-
-        Mono<SKContext> result = semanticKernel.getFunction("RAG", "AnswerQuestion").invokeAsync(skcontext);
+        Mono<String> reply = chat.generateMessageAsync(history, null);

         return new RAGResponse.Builder()
                 //.prompt(plan.toPlanString())
                 .prompt("placeholders for prompt")
-                .answer(result.block().getResult())
+                .answer(reply.block())
                 .sources(sourcesList)
                 .sourcesAsText(sources)
                 .question(question)
                 .build();
-
     }

     @Override
     public void runStreaming(ChatGPTConversation questionOrConversation, RAGOptions options, OutputStream outputStream) {
         throw new IllegalStateException("Streaming not supported for this approach");
     }

+    private OpenAIChatHistory buildChatHistory(ChatGPTConversation conversation, RAGOptions options, OpenAIChatCompletion chat,
+                                               String sources) {
+        String systemMessage = SYSTEM_CHAT_MESSAGE_TEMPLATE.formatted(
+                options.isSuggestFollowupQuestions() ? FOLLOW_UP_QUESTIONS_TEMPLATE : "",
+                options.getPromptTemplate() != null ? options.getPromptTemplate() : "",
+                sources);
+
+        OpenAIChatHistory chatHistory = chat.createNewChat(systemMessage);
+        conversation.getMessages().forEach(message -> {
+            if(message.role() == ChatGPTMessage.ChatRole.USER){
+                chatHistory.addUserMessage(message.content());
+            } else if(message.role() == ChatGPTMessage.ChatRole.ASSISTANT) {
+                chatHistory.addAssistantMessage(message.content());
+            }
+        });
+
+        return chatHistory;
+    }
+
     private List<ContentSource> buildSources(List<MemoryQueryResult> memoryResult) {
         return memoryResult
                 .stream()
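Note on the new system prompt: the three %s placeholders in SYSTEM_CHAT_MESSAGE_TEMPLATE are filled, in order, with the optional follow-up-questions instructions, the caller-supplied prompt template, and the retrieved sources, exactly as buildChatHistory does above. Below is a minimal, self-contained sketch of that substitution; the constant bodies are trimmed and the sample values are illustrative, not taken from the repository.

public class SystemMessageSketch {

    // Trimmed stand-ins for the constants added in this commit (illustrative only).
    private static final String FOLLOW_UP_QUESTIONS_TEMPLATE = """
            After answering question, also generate three very brief follow-up questions that the user would likely ask next.
            """;
    private static final String SYSTEM_CHAT_MESSAGE_TEMPLATE = """
            Assistant helps the company employees with their healthcare plan questions. Be brief in your answers.
            %s

            %s
            Sources:
            %s
            """;

    public static void main(String[] args) {
        boolean suggestFollowupQuestions = true;  // would come from RAGOptions.isSuggestFollowupQuestions()
        String promptTemplate = null;             // would come from RAGOptions.getPromptTemplate()
        String sources = "info1.txt: Standard plan covers annual eye exams.";

        // Same ordering as buildChatHistory: follow-ups, custom prompt, sources.
        String systemMessage = SYSTEM_CHAT_MESSAGE_TEMPLATE.formatted(
                suggestFollowupQuestions ? FOLLOW_UP_QUESTIONS_TEMPLATE : "",
                promptTemplate != null ? promptTemplate : "",
                sources);

        System.out.println(systemMessage);
    }
}

Passing an empty string for an unused slot leaves the rest of the template intact, which is why the false/null checks in buildChatHistory fall back to "".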

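For context, here is a minimal sketch of the chat-completion flow this commit switches to, using only the Semantic Kernel calls that appear in the diff (getService, createNewChat, addUserMessage, generateMessageAsync). The class and method names are illustrative; the kernel is assumed to be configured with an OpenAI chat-completion service, as in JavaSemanticKernelWithMemoryChatApproach, and the system message and question are supplied by the caller.

import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.chatcompletion.ChatCompletion;
import com.microsoft.semantickernel.connectors.ai.openai.chatcompletion.OpenAIChatCompletion;
import com.microsoft.semantickernel.connectors.ai.openai.chatcompletion.OpenAIChatHistory;
import reactor.core.publisher.Mono;

public class ChatCompletionFlowSketch {

    // Generates one reply from a kernel that already has an OpenAI chat-completion
    // service registered, mirroring the run() method in the diff above.
    static String reply(Kernel semanticKernel, String systemMessage, String question) {
        // Resolve the registered ChatCompletion service, as run() does.
        OpenAIChatCompletion chat =
                (OpenAIChatCompletion) semanticKernel.getService(null, ChatCompletion.class);

        // Seed the history with the system prompt, then add the user turn.
        OpenAIChatHistory history = chat.createNewChat(systemMessage);
        history.addUserMessage(question);

        // generateMessageAsync returns a Mono<String>; block for a synchronous
        // result, mirroring reply.block() in run().
        Mono<String> reply = chat.generateMessageAsync(history, null);
        return reply.block();
    }
}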