Skip to content

Commit a01c7c0

Browse files
committed
Add streaming capability
1 parent fc733e6 commit a01c7c0

File tree

16 files changed

+261
-75
lines changed

16 files changed

+261
-75
lines changed

app/backend/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@
5454
<groupId>org.springframework.boot</groupId>
5555
<artifactId>spring-boot-starter-web</artifactId>
5656
</dependency>
57+
<dependency>
58+
<groupId>org.springframework.boot</groupId>
59+
<artifactId>spring-boot-starter-webflux</artifactId>
60+
</dependency>
5761

5862
<dependency>
5963
<groupId>org.springframework.boot</groupId>

app/backend/src/main/java/com/microsoft/openai/samples/rag/Application.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,4 @@ public static void main(String[] args) {
1414
LOG.info("Application profile from system property is [{}]", System.getProperty("spring.profiles.active"));
1515
new SpringApplication(Application.class).run(args);
1616
}
17-
1817
}
Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
package com.microsoft.openai.samples.rag.approaches;
22

3+
import reactor.core.publisher.Flux;
4+
35
public interface RAGApproach<I, O> {
46

57
O run(I questionOrConversation, RAGOptions options);
6-
7-
8-
9-
10-
8+
Flux<O> runStreaming(I questionOrConversation, RAGOptions options);
119
}

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/PlainJavaAskApproach.java

Lines changed: 68 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,27 @@
11
package com.microsoft.openai.samples.rag.ask.approaches;
22

33
import com.azure.ai.openai.models.ChatCompletions;
4+
import com.azure.ai.openai.models.ChatCompletionsOptions;
5+
import com.azure.ai.openai.models.ChatMessage;
46
import com.microsoft.openai.samples.rag.approaches.ContentSource;
57
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
68
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
79
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
810
import com.microsoft.openai.samples.rag.common.ChatGPTUtils;
9-
import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider;
1011
import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
12+
import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider;
1113
import com.microsoft.openai.samples.rag.retrieval.Retriever;
1214
import org.slf4j.Logger;
1315
import org.slf4j.LoggerFactory;
1416
import org.springframework.stereotype.Component;
17+
import reactor.core.publisher.Flux;
1518

1619
import java.util.List;
1720

1821
/**
1922
* Simple retrieve-then-read java implementation, using the Cognitive Search and OpenAI APIs directly. It first retrieves
20-
* top documents from search, then constructs a prompt with them, and then uses OpenAI to generate a completion
21-
* (answer) with that prompt.
23+
* top documents from search, then constructs a prompt with them, and then uses OpenAI to generate a completion
24+
* (answer) with that prompt.
2225
*/
2326
@Component
2427
public class PlainJavaAskApproach implements RAGApproach<String, RAGResponse> {
@@ -39,8 +42,65 @@ public PlainJavaAskApproach(FactsRetrieverProvider factsRetrieverProvider, OpenA
3942
*/
4043
@Override
4144
public RAGResponse run(String question, RAGOptions options) {
42-
//TODO exception handling
45+
return formChatCompletionArguments(
46+
question,
47+
options,
48+
(chatCompletionsOptions, groundedChatMessages, sources) -> {
49+
// STEP 3: Generate a contextual and content specific answer using the retrieve facts
50+
ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions);
51+
52+
LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
53+
chatCompletions.getUsage().getPromptTokens(),
54+
chatCompletions.getUsage().getCompletionTokens(),
55+
chatCompletions.getUsage().getTotalTokens());
56+
57+
return new RAGResponse.Builder()
58+
.question(question)
59+
.prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
60+
.answer(chatCompletions.getChoices().get(0).getMessage().getContent())
61+
.sources(sources)
62+
.build();
63+
});
64+
}
65+
66+
@Override
67+
public Flux<RAGResponse> runStreaming(String question, RAGOptions options) {
68+
return formChatCompletionArguments(
69+
question,
70+
options,
71+
(chatCompletionsOptions, groundedChatMessages, sources) -> {
72+
Flux<ChatCompletions> completions = Flux.fromIterable(openAIProxy.getChatCompletionsStream(chatCompletionsOptions));
73+
return completions
74+
.flatMap(completion -> {
75+
LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
76+
completion.getUsage().getPromptTokens(),
77+
completion.getUsage().getCompletionTokens(),
78+
completion.getUsage().getTotalTokens());
79+
80+
return Flux.fromIterable(completion.getChoices())
81+
.map(choice -> new RAGResponse.Builder()
82+
.question(question)
83+
.prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
84+
.answer(choice.getMessage().getContent())
85+
.sources(sources)
86+
.build());
87+
});
88+
});
89+
}
4390

91+
private interface CompletionFunction<T> {
92+
T apply(
93+
ChatCompletionsOptions chatCompletionsOptions,
94+
List<ChatMessage> groundedChatMessages,
95+
List<ContentSource> sources
96+
);
97+
}
98+
99+
private <T> T formChatCompletionArguments(
100+
String question,
101+
RAGOptions options,
102+
CompletionFunction<T> completionFunction
103+
) {
44104
//Get instance of retriever based on the retrieval mode: hybrid, text, vectors.
45105
Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options);
46106
List<ContentSource> sources = factsRetriever.retrieveFromQuestion(question, options);
@@ -51,31 +111,17 @@ public RAGResponse run(String question, RAGOptions options) {
51111
var customPromptEmpty = (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty());
52112

53113
//True will replace the default prompt; false will append the custom prompt as a suffix to the default prompt.
54-
var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|");
55-
if(!replacePrompt && !customPromptEmpty){
114+
var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|");
115+
if (!replacePrompt && !customPromptEmpty) {
56116
customPrompt = customPrompt.substring(1);
57117
}
58118

59119
var answerQuestionChatTemplate = new AnswerQuestionChatTemplate(customPrompt, replacePrompt);
60120

61-
var groundedChatMessages = answerQuestionChatTemplate.getMessages(question,sources);
121+
var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources);
62122
var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages);
63123

64-
// STEP 3: Generate a contextual and content specific answer using the retrieve facts
65-
ChatCompletions chatCompletions = openAIProxy.getChatCompletions(chatCompletionsOptions);
66-
67-
LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
68-
chatCompletions.getUsage().getPromptTokens(),
69-
chatCompletions.getUsage().getCompletionTokens(),
70-
chatCompletions.getUsage().getTotalTokens());
71-
72-
return new RAGResponse.Builder()
73-
.question(question)
74-
.prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
75-
.answer(chatCompletions.getChoices().get(0).getMessage().getContent())
76-
.sources(sources)
77-
.build();
78-
124+
return completionFunction.apply(chatCompletionsOptions, groundedChatMessages, sources);
79125
}
80126

81127

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelChainsApproach.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.slf4j.LoggerFactory;
1515
import org.springframework.beans.factory.annotation.Value;
1616
import org.springframework.stereotype.Component;
17+
import reactor.core.publisher.Flux;
1718

1819
import java.util.Arrays;
1920
import java.util.Collections;
@@ -81,6 +82,11 @@ public RAGResponse run(String question, RAGOptions options) {
8182

8283
}
8384

85+
@Override
86+
public Flux<RAGResponse> runStreaming(String questionOrConversation, RAGOptions options) {
87+
return Flux.error(new IllegalStateException("Streaming not supported for this approach"));
88+
}
89+
8490
private List<ContentSource> formSourcesList(String result) {
8591
if (result == null) {
8692
return Collections.emptyList();

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelPlannerApproach.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.slf4j.LoggerFactory;
1616
import org.springframework.beans.factory.annotation.Value;
1717
import org.springframework.stereotype.Component;
18+
import reactor.core.publisher.Flux;
1819

1920
import java.util.Objects;
2021
import java.util.Set;
@@ -80,6 +81,11 @@ public RAGResponse run(String question, RAGOptions options) {
8081

8182
}
8283

84+
@Override
85+
public Flux<RAGResponse> runStreaming(String questionOrConversation, RAGOptions options) {
86+
return Flux.error(new IllegalStateException("Streaming not supported for this approach"));
87+
}
88+
8389
private Kernel buildSemanticKernel( RAGOptions options) {
8490
Kernel kernel = SKBuilders.kernel()
8591
.withDefaultAIService(SKBuilders.chatCompletion()

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelWithMemoryApproach.java

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@
1919
import org.slf4j.LoggerFactory;
2020
import org.springframework.beans.factory.annotation.Value;
2121
import org.springframework.stereotype.Component;
22+
import reactor.core.publisher.Flux;
2223
import reactor.core.publisher.Mono;
2324

2425
import java.util.List;
2526
import java.util.function.Function;
2627
import java.util.stream.Collectors;
2728

2829
/**
29-
* Accomplish the same task as in the PlainJavaAskApproach approach but using Semantic Kernel framework:
30-
* 1. Memory abstraction is used for vector search capability. It uses Azure Cognitive Search as memory store.
31-
* 2. Semantic function has been defined to ask question using sources from memory search results
30+
* Accomplish the same task as in the PlainJavaAskApproach approach but using Semantic Kernel framework:
31+
* 1. Memory abstraction is used for vector search capability. It uses Azure Cognitive Search as memory store.
32+
* 2. Semantic function has been defined to ask question using sources from memory search results
3233
*/
3334
@Component
3435
public class JavaSemanticKernelWithMemoryApproach implements RAGApproach<String, RAGResponse> {
@@ -40,8 +41,10 @@ public class JavaSemanticKernelWithMemoryApproach implements RAGApproach<String,
4041

4142
private final String EMBEDDING_FIELD_NAME = "embedding";
4243

43-
@Value("${cognitive.search.service}") String searchServiceName ;
44-
@Value("${cognitive.search.index}") String indexName;
44+
@Value("${cognitive.search.service}")
45+
String searchServiceName;
46+
@Value("${cognitive.search.index}")
47+
String indexName;
4548
@Value("${openai.chatgpt.deployment}")
4649
private String gptChatDeploymentModelId;
4750

@@ -70,11 +73,11 @@ public RAGResponse run(String question, RAGOptions options) {
7073
* Question embeddings are provided to cognitive search via search options.
7174
*/
7275
List<MemoryQueryResult> memoryResult = semanticKernel.getMemory().searchAsync(
73-
indexName,
74-
question,
75-
options.getTop(),
76-
0.5f,
77-
false)
76+
indexName,
77+
question,
78+
options.getTop(),
79+
0.5f,
80+
false)
7881
.block();
7982

8083
LOGGER.info("Total {} sources found in cognitive vector store for search query[{}]", memoryResult.size(), question);
@@ -90,14 +93,19 @@ public RAGResponse run(String question, RAGOptions options) {
9093
Mono<SKContext> result = semanticKernel.getFunction("RAG", "AnswerQuestion").invokeAsync(skcontext);
9194

9295
return new RAGResponse.Builder()
93-
//.prompt(plan.toPlanString())
94-
.prompt("placeholders for prompt")
95-
.answer(result.block().getResult())
96-
.sources(sourcesList)
97-
.sourcesAsText(sources)
98-
.question(question)
99-
.build();
96+
//.prompt(plan.toPlanString())
97+
.prompt("placeholders for prompt")
98+
.answer(result.block().getResult())
99+
.sources(sourcesList)
100+
.sourcesAsText(sources)
101+
.question(question)
102+
.build();
103+
104+
}
100105

106+
@Override
107+
public Flux<RAGResponse> runStreaming(String questionOrConversation, RAGOptions options) {
108+
return Flux.error(new IllegalStateException("Streaming not supported for this approach"));
101109
}
102110

103111
private List<ContentSource> buildSources(List<MemoryQueryResult> memoryResult) {
@@ -123,15 +131,14 @@ private String buildSourcesText(List<MemoryQueryResult> memoryResult) {
123131
return sourcesContentBuffer.toString();
124132
}
125133

126-
private Kernel buildSemanticKernel( RAGOptions options) {
127-
134+
private Kernel buildSemanticKernel(RAGOptions options) {
128135
var kernelWithACS = SKBuilders.kernel()
129136
.withMemoryStorage(
130137
new CustomAzureCognitiveSearchMemoryStore("https://%s.search.windows.net".formatted(searchServiceName),
131-
tokenCredential,
132-
this.searchAsyncClient,
133-
this.EMBEDDING_FIELD_NAME,
134-
buildCustomMemoryMapper()))
138+
tokenCredential,
139+
this.searchAsyncClient,
140+
this.EMBEDDING_FIELD_NAME,
141+
buildCustomMemoryMapper()))
135142
.withDefaultAIService(SKBuilders.textEmbeddingGeneration()
136143
.withOpenAIClient(openAIAsyncClient)
137144
.withModelId(embeddingDeploymentModelId)
@@ -142,14 +149,13 @@ private Kernel buildSemanticKernel( RAGOptions options) {
142149
.build())
143150
.build();
144151

145-
kernelWithACS.importSkillFromResources("semantickernel/Plugins","RAG","AnswerQuestion",null);
146-
return kernelWithACS;
152+
kernelWithACS.importSkillFromResources("semantickernel/Plugins", "RAG", "AnswerQuestion", null);
153+
return kernelWithACS;
147154
}
148155

149-
150-
private Function<SearchDocument, MemoryRecord> buildCustomMemoryMapper(){
156+
private Function<SearchDocument, MemoryRecord> buildCustomMemoryMapper() {
151157
return searchDocument -> {
152-
return MemoryRecord.localRecord(
158+
return MemoryRecord.localRecord(
153159
(String) searchDocument.get("sourcepage"),
154160
(String) searchDocument.get("content"),
155161
"chunked text from original source",

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/controller/AskController.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
import com.microsoft.openai.samples.rag.approaches.RAGType;
88
import com.microsoft.openai.samples.rag.controller.ChatAppRequest;
99
import com.microsoft.openai.samples.rag.controller.ChatResponse;
10-
import com.microsoft.openai.samples.rag.controller.ResponseChoice;
11-
import com.microsoft.openai.samples.rag.controller.ResponseContext;
12-
import com.microsoft.openai.samples.rag.controller.ResponseMessage;
13-
import com.microsoft.openai.samples.rag.common.ChatGPTMessage;
1410
import org.slf4j.Logger;
1511
import org.slf4j.LoggerFactory;
1612
import org.springframework.http.HttpStatus;
@@ -20,9 +16,6 @@
2016
import org.springframework.web.bind.annotation.RequestBody;
2117
import org.springframework.web.bind.annotation.RestController;
2218

23-
import java.util.Collections;
24-
import java.util.List;
25-
2619
@RestController
2720
public class AskController {
2821

@@ -60,8 +53,7 @@ public ResponseEntity<ChatResponse> openAIAsk(@RequestBody ChatAppRequest askReq
6053

6154
RAGApproach<String, RAGResponse> ragApproach = ragApproachFactory.createApproach(askRequest.approach(), RAGType.ASK, ragOptions);
6255

56+
6357
return ResponseEntity.ok(ChatResponse.buildChatResponse(ragApproach.run(question, ragOptions)));
6458
}
65-
66-
6759
}

0 commit comments

Comments
 (0)