Skip to content

Commit f51e6f5

Browse files
committed
store wip
1 parent 86c5ed5 commit f51e6f5

File tree

8 files changed

+420
-23
lines changed

8 files changed

+420
-23
lines changed

app/backend/src/main/java/com/microsoft/openai/samples/rag/Application.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import org.slf4j.LoggerFactory;
55
import org.springframework.boot.SpringApplication;
66
import org.springframework.boot.autoconfigure.SpringBootApplication;
7+
import org.springframework.context.annotation.Bean;
8+
import org.springframework.web.servlet.config.annotation.CorsRegistry;
9+
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
710

811
@SpringBootApplication
912
public class Application {
@@ -14,4 +17,16 @@ public static void main(String[] args) {
1417
LOG.info("Application profile from system property is [{}]", System.getProperty("spring.profiles.active"));
1518
new SpringApplication(Application.class).run(args);
1619
}
20+
21+
@Bean
22+
public WebMvcConfigurer corsConfigurer() {
23+
return new WebMvcConfigurer() {
24+
@Override
25+
public void addCorsMappings(CorsRegistry registry) {
26+
registry
27+
.addMapping("/api/**")
28+
.allowedOrigins("http://localhost:8080");
29+
}
30+
};
31+
}
1732
}

app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproachFactorySpringBootImpl.java

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import com.microsoft.openai.samples.rag.ask.approaches.PlainJavaAskApproach;
44
import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelChainsApproach;
5-
import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelWithMemoryApproach;
65
import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelPlannerApproach;
6+
import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelWithMemoryApproach;
77
import com.microsoft.openai.samples.rag.chat.approaches.PlainJavaChatApproach;
8+
import com.microsoft.openai.samples.rag.chat.approaches.semantickernel.JavaSemanticKernelChainsChatApproach;
9+
import com.microsoft.openai.samples.rag.chat.approaches.semantickernel.JavaSemanticKernelWithMemoryChatApproach;
810
import org.springframework.context.ApplicationContext;
911
import org.springframework.context.ApplicationContextAware;
1012
import org.springframework.stereotype.Component;
@@ -27,18 +29,27 @@ public class RAGApproachFactorySpringBootImpl implements RAGApproachFactory, App
2729
@Override
2830
public RAGApproach createApproach(String approachName, RAGType ragType, RAGOptions ragOptions) {
2931

30-
if (ragType.equals(RAGType.CHAT) && JAVA_OPENAI_SDK.equals(approachName)) {
31-
return applicationContext.getBean(PlainJavaChatApproach.class);
32-
32+
if (ragType.equals(RAGType.CHAT)) {
33+
if (JAVA_SEMANTIC_KERNEL.equals(approachName)) {
34+
return applicationContext.getBean(JavaSemanticKernelWithMemoryChatApproach.class);
35+
} else if (
36+
JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) &&
37+
ragOptions != null &&
38+
ragOptions.getSemantickKernelMode() != null &&
39+
ragOptions.getSemantickKernelMode() == SemanticKernelMode.chains) {
40+
return applicationContext.getBean(JavaSemanticKernelChainsChatApproach.class);
41+
} else {
42+
return applicationContext.getBean(PlainJavaChatApproach.class);
43+
}
3344
} else if (ragType.equals(RAGType.ASK)) {
3445
if (JAVA_OPENAI_SDK.equals(approachName))
3546
return applicationContext.getBean(PlainJavaAskApproach.class);
3647
else if (JAVA_SEMANTIC_KERNEL.equals(approachName))
3748
return applicationContext.getBean(JavaSemanticKernelWithMemoryApproach.class);
3849
else if (JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) && ragOptions.getSemantickKernelMode() != null && ragOptions.getSemantickKernelMode() == SemanticKernelMode.planner)
39-
return applicationContext.getBean(JavaSemanticKernelPlannerApproach.class);
40-
else if(JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) && ragOptions != null && ragOptions.getSemantickKernelMode() != null && ragOptions.getSemantickKernelMode() == SemanticKernelMode.chains)
41-
return applicationContext.getBean(JavaSemanticKernelChainsApproach.class);
50+
return applicationContext.getBean(JavaSemanticKernelPlannerApproach.class);
51+
else if (JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) && ragOptions != null && ragOptions.getSemantickKernelMode() != null && ragOptions.getSemantickKernelMode() == SemanticKernelMode.chains)
52+
return applicationContext.getBean(JavaSemanticKernelChainsApproach.class);
4253

4354
}
4455
//if this point is reached then the combination of approach and rag type is not supported

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelChainsApproach.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,4 @@ private Kernel buildSemanticKernel( RAGOptions options) {
127127

128128
return kernel;
129129
}
130-
131-
132-
133130
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package com.microsoft.openai.samples.rag.chat.approaches.semantickernel;
2+
3+
import com.azure.ai.openai.OpenAIAsyncClient;
4+
import com.microsoft.openai.samples.rag.approaches.ContentSource;
5+
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
6+
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
7+
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
8+
import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.CognitiveSearchPlugin;
9+
import com.microsoft.openai.samples.rag.common.ChatGPTConversation;
10+
import com.microsoft.openai.samples.rag.common.ChatGPTUtils;
11+
import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy;
12+
import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
13+
import com.microsoft.semantickernel.Kernel;
14+
import com.microsoft.semantickernel.SKBuilders;
15+
import com.microsoft.semantickernel.chatcompletion.ChatCompletion;
16+
import com.microsoft.semantickernel.orchestration.SKContext;
17+
import org.slf4j.Logger;
18+
import org.slf4j.LoggerFactory;
19+
import org.springframework.beans.factory.annotation.Value;
20+
import org.springframework.stereotype.Component;
21+
22+
import java.io.OutputStream;
23+
import java.util.Arrays;
24+
import java.util.Collections;
25+
import java.util.List;
26+
import java.util.Objects;
27+
import java.util.stream.Collectors;
28+
29+
/**
30+
* Simple chat-read-retrieve-read java implementation, using the Cognitive Search and OpenAI APIs directly.
31+
* It uses the ChatGPT API to turn the user question into a good search query.
32+
* It queries Azure Cognitive Search for search results for that query (optionally using the vector embeddings for that query).
33+
* It then combines the search results and original user question, and asks ChatGPT API to answer the question based on the sources. It includes the last 4K of message history as well (or however many tokens are allowed by the deployed model).
34+
*/
35+
@Component
36+
public class JavaSemanticKernelChainsChatApproach implements RAGApproach<ChatGPTConversation, RAGResponse> {
37+
38+
private static final Logger LOGGER = LoggerFactory.getLogger(JavaSemanticKernelChainsChatApproach.class);
39+
private static final String PLAN_PROMPT = """
40+
Take the input as a question and answer it finding any information needed
41+
""";
42+
private final CognitiveSearchProxy cognitiveSearchProxy;
43+
44+
private final OpenAIProxy openAIProxy;
45+
46+
private final OpenAIAsyncClient openAIAsyncClient;
47+
48+
@Value("${openai.chatgpt.deployment}")
49+
private String gptChatDeploymentModelId;
50+
51+
public JavaSemanticKernelChainsChatApproach(CognitiveSearchProxy cognitiveSearchProxy, OpenAIAsyncClient openAIAsyncClient, OpenAIProxy openAIProxy) {
52+
this.cognitiveSearchProxy = cognitiveSearchProxy;
53+
this.openAIAsyncClient = openAIAsyncClient;
54+
this.openAIProxy = openAIProxy;
55+
}
56+
57+
/**
58+
* @param questionOrConversation
59+
* @param options
60+
* @return
61+
*/
62+
@Override
63+
public RAGResponse run(ChatGPTConversation questionOrConversation, RAGOptions options) {
64+
65+
String question = ChatGPTUtils.getLastUserQuestion(questionOrConversation.getMessages());
66+
67+
Kernel semanticKernel = buildSemanticKernel(options);
68+
69+
SKContext searchContext =
70+
semanticKernel.runAsync(
71+
question,
72+
semanticKernel.getSkill("InformationFinder").getFunction("Search", null)).block();
73+
74+
var sources = formSourcesList(searchContext.getResult());
75+
76+
var answerVariables = SKBuilders.variables()
77+
.withVariable("sources", searchContext.getResult())
78+
.withVariable("input", question)
79+
.build();
80+
81+
SKContext answerExecutionContext =
82+
semanticKernel.runAsync(answerVariables,
83+
semanticKernel.getSkill("RAG").getFunction("AnswerQuestion", null)).block();
84+
return new RAGResponse.Builder()
85+
.prompt("Prompt is managed by Semantic Kernel")
86+
.answer(answerExecutionContext.getResult())
87+
.sources(sources)
88+
.sourcesAsText(searchContext.getResult())
89+
.question(question)
90+
.build();
91+
}
92+
93+
@Override
94+
public void runStreaming(
95+
ChatGPTConversation questionOrConversation,
96+
RAGOptions options,
97+
OutputStream outputStream) {
98+
}
99+
100+
private List<ContentSource> formSourcesList(String result) {
101+
if (result == null) {
102+
return Collections.emptyList();
103+
}
104+
return Arrays.stream(result
105+
.split("\n"))
106+
.map(source -> {
107+
String[] split = source.split(":", 2);
108+
if (split.length >= 2) {
109+
var sourceName = split[0].trim();
110+
var sourceContent = split[1].trim();
111+
return new ContentSource(sourceName, sourceContent);
112+
} else {
113+
return null;
114+
}
115+
})
116+
.filter(Objects::nonNull)
117+
.collect(Collectors.toList());
118+
}
119+
120+
private Kernel buildSemanticKernel(RAGOptions options) {
121+
Kernel kernel = SKBuilders.kernel()
122+
.withDefaultAIService(SKBuilders.chatCompletion()
123+
.withModelId(gptChatDeploymentModelId)
124+
.withOpenAIClient(this.openAIAsyncClient)
125+
.build())
126+
.build();
127+
128+
kernel.importSkill(new CognitiveSearchPlugin(this.cognitiveSearchProxy, this.openAIProxy, options), "InformationFinder");
129+
130+
kernel.importSkillFromResources(
131+
"semantickernel/Plugins",
132+
"RAG",
133+
"AnswerQuestion",
134+
null
135+
);
136+
137+
return kernel;
138+
}
139+
140+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package com.microsoft.openai.samples.rag.chat.approaches.semantickernel;
2+
3+
import com.azure.ai.openai.OpenAIAsyncClient;
4+
import com.azure.core.credential.TokenCredential;
5+
import com.azure.search.documents.SearchAsyncClient;
6+
import com.azure.search.documents.SearchDocument;
7+
import com.microsoft.openai.samples.rag.approaches.ContentSource;
8+
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
9+
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
10+
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
11+
import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.memory.CustomAzureCognitiveSearchMemoryStore;
12+
import com.microsoft.openai.samples.rag.common.ChatGPTConversation;
13+
import com.microsoft.openai.samples.rag.common.ChatGPTUtils;
14+
import com.microsoft.semantickernel.Kernel;
15+
import com.microsoft.semantickernel.SKBuilders;
16+
import com.microsoft.semantickernel.ai.embeddings.Embedding;
17+
import com.microsoft.semantickernel.memory.MemoryQueryResult;
18+
import com.microsoft.semantickernel.memory.MemoryRecord;
19+
import com.microsoft.semantickernel.orchestration.SKContext;
20+
import org.slf4j.Logger;
21+
import org.slf4j.LoggerFactory;
22+
import org.springframework.beans.factory.annotation.Value;
23+
import org.springframework.stereotype.Component;
24+
import reactor.core.publisher.Mono;
25+
26+
import java.io.OutputStream;
27+
import java.util.List;
28+
import java.util.function.Function;
29+
import java.util.stream.Collectors;
30+
31+
/**
32+
* Accomplish the same task as in the PlainJavaAskApproach approach but using Semantic Kernel framework:
33+
* 1. Memory abstraction is used for vector search capability. It uses Azure Cognitive Search as memory store.
34+
* 2. Semantic function has been defined to ask question using sources from memory search results
35+
*/
36+
@Component
37+
public class JavaSemanticKernelWithMemoryChatApproach implements RAGApproach<ChatGPTConversation, RAGResponse> {
38+
private static final Logger LOGGER = LoggerFactory.getLogger(JavaSemanticKernelWithMemoryChatApproach.class);
39+
private final TokenCredential tokenCredential;
40+
private final OpenAIAsyncClient openAIAsyncClient;
41+
42+
private final SearchAsyncClient searchAsyncClient;
43+
44+
private final String EMBEDDING_FIELD_NAME = "embedding";
45+
46+
@Value("${cognitive.search.service}")
47+
String searchServiceName;
48+
@Value("${cognitive.search.index}")
49+
String indexName;
50+
@Value("${openai.chatgpt.deployment}")
51+
private String gptChatDeploymentModelId;
52+
53+
@Value("${openai.embedding.deployment}")
54+
private String embeddingDeploymentModelId;
55+
56+
public JavaSemanticKernelWithMemoryChatApproach(TokenCredential tokenCredential, OpenAIAsyncClient openAIAsyncClient, SearchAsyncClient searchAsyncClient) {
57+
this.tokenCredential = tokenCredential;
58+
this.openAIAsyncClient = openAIAsyncClient;
59+
this.searchAsyncClient = searchAsyncClient;
60+
}
61+
62+
@Override
63+
public RAGResponse run(ChatGPTConversation questionOrConversation, RAGOptions options) {
64+
65+
String question = ChatGPTUtils.getLastUserQuestion(questionOrConversation.getMessages());
66+
67+
//Build semantic kernel with Azure Cognitive Search as memory store. AnswerQuestion skill is imported from resources.
68+
Kernel semanticKernel = buildSemanticKernel(options);
69+
70+
/**
71+
* Use semantic kernel built-in memory.searchAsync. It uses OpenAI to generate embeddings for the provided question.
72+
* Question embeddings are provided to cognitive search via search options.
73+
*/
74+
List<MemoryQueryResult> memoryResult = semanticKernel.getMemory().searchAsync(
75+
indexName,
76+
question,
77+
options.getTop(),
78+
0.5f,
79+
false)
80+
.block();
81+
82+
LOGGER.info("Total {} sources found in cognitive vector store for search query[{}]", memoryResult.size(), question);
83+
84+
String sources = buildSourcesText(memoryResult);
85+
List<ContentSource> sourcesList = buildSources(memoryResult);
86+
87+
SKContext skcontext = SKBuilders.context().build()
88+
.setVariable("sources", sources)
89+
.setVariable("input", question);
90+
91+
92+
Mono<SKContext> result = semanticKernel.getFunction("RAG", "AnswerQuestion").invokeAsync(skcontext);
93+
94+
return new RAGResponse.Builder()
95+
//.prompt(plan.toPlanString())
96+
.prompt("placeholders for prompt")
97+
.answer(result.block().getResult())
98+
.sources(sourcesList)
99+
.sourcesAsText(sources)
100+
.question(question)
101+
.build();
102+
103+
}
104+
105+
@Override
106+
public void runStreaming(ChatGPTConversation questionOrConversation, RAGOptions options, OutputStream outputStream) {
107+
throw new IllegalStateException("Streaming not supported for this approach");
108+
}
109+
110+
private List<ContentSource> buildSources(List<MemoryQueryResult> memoryResult) {
111+
return memoryResult
112+
.stream()
113+
.map(result -> {
114+
return new ContentSource(
115+
result.getMetadata().getId(),
116+
result.getMetadata().getText()
117+
);
118+
})
119+
.collect(Collectors.toList());
120+
}
121+
122+
private String buildSourcesText(List<MemoryQueryResult> memoryResult) {
123+
StringBuilder sourcesContentBuffer = new StringBuilder();
124+
memoryResult.stream().forEach(memory -> {
125+
sourcesContentBuffer.append(memory.getMetadata().getId())
126+
.append(": ")
127+
.append(memory.getMetadata().getText().replace("\n", ""))
128+
.append("\n");
129+
});
130+
return sourcesContentBuffer.toString();
131+
}
132+
133+
private Kernel buildSemanticKernel(RAGOptions options) {
134+
var kernelWithACS = SKBuilders.kernel()
135+
.withMemoryStorage(
136+
new CustomAzureCognitiveSearchMemoryStore("https://%s.search.windows.net".formatted(searchServiceName),
137+
tokenCredential,
138+
this.searchAsyncClient,
139+
this.EMBEDDING_FIELD_NAME,
140+
buildCustomMemoryMapper()))
141+
.withDefaultAIService(SKBuilders.textEmbeddingGeneration()
142+
.withOpenAIClient(openAIAsyncClient)
143+
.withModelId(embeddingDeploymentModelId)
144+
.build())
145+
.withDefaultAIService(SKBuilders.chatCompletion()
146+
.withModelId(gptChatDeploymentModelId)
147+
.withOpenAIClient(this.openAIAsyncClient)
148+
.build())
149+
.build();
150+
151+
kernelWithACS.importSkillFromResources("semantickernel/Plugins", "RAG", "AnswerQuestion", null);
152+
return kernelWithACS;
153+
}
154+
155+
private Function<SearchDocument, MemoryRecord> buildCustomMemoryMapper() {
156+
return searchDocument -> {
157+
return MemoryRecord.localRecord(
158+
(String) searchDocument.get("sourcepage"),
159+
(String) searchDocument.get("content"),
160+
"chunked text from original source",
161+
new Embedding((List<Float>) searchDocument.get(EMBEDDING_FIELD_NAME)),
162+
(String) searchDocument.get("category"),
163+
(String) searchDocument.get("id"),
164+
null);
165+
166+
};
167+
}
168+
}

0 commit comments

Comments
 (0)