
Commit 4296b62

Merge pull request #38 from johnoliver/streaming-3
Add streaming capability
2 parents fc733e6 + 86c5ed5 commit 4296b62

18 files changed, +380 -66 lines changed

app/backend/pom.xml

Lines changed: 11 additions & 0 deletions
@@ -55,6 +55,17 @@
         <artifactId>spring-boot-starter-web</artifactId>
     </dependency>

+    <dependency>
+        <groupId>com.fasterxml.jackson.core</groupId>
+        <artifactId>jackson-databind</artifactId>
+        <scope>compile</scope>
+    </dependency>
+    <dependency>
+        <groupId>com.fasterxml.jackson.core</groupId>
+        <artifactId>jackson-core</artifactId>
+        <scope>compile</scope>
+    </dependency>
+
     <dependency>
         <groupId>org.springframework.boot</groupId>
         <artifactId>spring-boot-starter-test</artifactId>
app/backend/src/main/java/com/microsoft/openai/samples/rag/Application.java

Lines changed: 0 additions & 1 deletion
@@ -14,5 +14,4 @@ public static void main(String[] args) {
         LOG.info("Application profile from system property is [{}]", System.getProperty("spring.profiles.active"));
         new SpringApplication(Application.class).run(args);
     }
-
 }

app/backend/src/main/java/com/microsoft/openai/samples/rag/approaches/RAGApproach.java

Lines changed: 6 additions & 5 deletions
@@ -1,11 +1,12 @@
 package com.microsoft.openai.samples.rag.approaches;

-public interface RAGApproach<I, O> {
-
-    O run(I questionOrConversation, RAGOptions options);
-
-
+import com.microsoft.openai.samples.rag.common.ChatGPTConversation;
+import reactor.core.publisher.Flux;

+import java.io.OutputStream;

+public interface RAGApproach<I, O> {

+    O run(I questionOrConversation, RAGOptions options);
+    void runStreaming(I questionOrConversation, RAGOptions options, OutputStream outputStream);
 }
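
The streaming contract hands a caller-supplied OutputStream to the approach, which writes newline-delimited JSON to it. A sketch of how a Spring MVC controller might drive runStreaming with StreamingResponseBody; the controller class, endpoint path, and request payload below are assumptions for illustration, not code from this commit (the actual controller changes live in other files of the commit):

import com.microsoft.openai.samples.rag.approaches.RAGApproach;
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;

// Hypothetical controller shown only to illustrate the runStreaming contract.
@RestController
public class StreamingAskController {

    private final RAGApproach<String, RAGResponse> ragApproach;

    public StreamingAskController(RAGApproach<String, RAGResponse> ragApproach) {
        this.ragApproach = ragApproach;
    }

    @PostMapping(value = "/api/ask/stream", produces = MediaType.APPLICATION_NDJSON_VALUE)
    public ResponseEntity<StreamingResponseBody> askStream(@RequestBody AskRequest request) {
        // The servlet container invokes the lambda with the response OutputStream;
        // every flushed line the approach writes reaches the client as it is produced.
        StreamingResponseBody body = outputStream ->
                ragApproach.runStreaming(request.question(), request.options(), outputStream);
        return ResponseEntity.ok(body);
    }

    // Hypothetical request payload carrying the question and an already-built RAGOptions.
    record AskRequest(String question, RAGOptions options) {}
}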

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/PlainJavaAskApproach.java

Lines changed: 75 additions & 10 deletions
@@ -1,35 +1,43 @@
 package com.microsoft.openai.samples.rag.ask.approaches;

+import com.azure.ai.openai.models.ChatChoice;
 import com.azure.ai.openai.models.ChatCompletions;
+import com.azure.core.util.IterableStream;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import com.microsoft.openai.samples.rag.approaches.ContentSource;
 import com.microsoft.openai.samples.rag.approaches.RAGApproach;
 import com.microsoft.openai.samples.rag.approaches.RAGOptions;
 import com.microsoft.openai.samples.rag.approaches.RAGResponse;
 import com.microsoft.openai.samples.rag.common.ChatGPTUtils;
-import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider;
+import com.microsoft.openai.samples.rag.controller.ChatResponse;
 import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
+import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider;
 import com.microsoft.openai.samples.rag.retrieval.Retriever;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.stereotype.Component;

+import java.io.IOException;
+import java.io.OutputStream;
 import java.util.List;

 /**
  * Simple retrieve-then-read java implementation, using the Cognitive Search and OpenAI APIs directly. It first retrieves
- * top documents from search, then constructs a prompt with them, and then uses OpenAI to generate a completion
- * (answer) with that prompt.
+ * top documents from search, then constructs a prompt with them, and then uses OpenAI to generate a completion
+ * (answer) with that prompt.
  */
 @Component
 public class PlainJavaAskApproach implements RAGApproach<String, RAGResponse> {

     private static final Logger LOGGER = LoggerFactory.getLogger(PlainJavaAskApproach.class);
     private final OpenAIProxy openAIProxy;
     private final FactsRetrieverProvider factsRetrieverProvider;
+    private final ObjectMapper objectMapper;

-    public PlainJavaAskApproach(FactsRetrieverProvider factsRetrieverProvider, OpenAIProxy openAIProxy) {
+    public PlainJavaAskApproach(FactsRetrieverProvider factsRetrieverProvider, OpenAIProxy openAIProxy, ObjectMapper objectMapper) {
         this.factsRetrieverProvider = factsRetrieverProvider;
         this.openAIProxy = openAIProxy;
+        this.objectMapper = objectMapper;
     }

     /**
@@ -39,8 +47,6 @@ public PlainJavaAskApproach(FactsRetrieverProvider factsRetrieverProvider, OpenA
      */
     @Override
     public RAGResponse run(String question, RAGOptions options) {
-        //TODO exception handling
-
         //Get instance of retriever based on the retrieval mode: hybryd, text, vectors.
         Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options);
         List<ContentSource> sources = factsRetriever.retrieveFromQuestion(question, options);
@@ -51,14 +57,14 @@ public RAGResponse run(String question, RAGOptions options) {
         var customPromptEmpty = (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty());

         //true will replace the default prompt. False will add custom prompt as suffix to the default prompt
-        var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|");
-        if(!replacePrompt && !customPromptEmpty){
+        var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|");
+        if (!replacePrompt && !customPromptEmpty) {
             customPrompt = customPrompt.substring(1);
         }

         var answerQuestionChatTemplate = new AnswerQuestionChatTemplate(customPrompt, replacePrompt);

-        var groundedChatMessages = answerQuestionChatTemplate.getMessages(question,sources);
+        var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources);
         var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages);

         // STEP 3: Generate a contextual and content specific answer using the retrieve facts
@@ -75,8 +81,67 @@ public RAGResponse run(String question, RAGOptions options) {
                 .answer(chatCompletions.getChoices().get(0).getMessage().getContent())
                 .sources(sources)
                 .build();
-
     }

+    @Override
+    public void runStreaming(String question, RAGOptions options, OutputStream outputStream) {
+        //Get instance of retriever based on the retrieval mode: hybryd, text, vectors.
+        Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options);
+        List<ContentSource> sources = factsRetriever.retrieveFromQuestion(question, options);
+        LOGGER.info("Total {} sources found in cognitive search for keyword search query[{}]", sources.size(),
+                question);

+        var customPrompt = options.getPromptTemplate();
+        var customPromptEmpty = (customPrompt == null) || (customPrompt != null && customPrompt.isEmpty());
+
+        //true will replace the default prompt. False will add custom prompt as suffix to the default prompt
+        var replacePrompt = !customPromptEmpty && !customPrompt.startsWith("|");
+        if (!replacePrompt && !customPromptEmpty) {
+            customPrompt = customPrompt.substring(1);
+        }
+
+        var answerQuestionChatTemplate = new AnswerQuestionChatTemplate(customPrompt, replacePrompt);
+
+        var groundedChatMessages = answerQuestionChatTemplate.getMessages(question, sources);
+        var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(groundedChatMessages);
+
+        IterableStream<ChatCompletions> completions = openAIProxy.getChatCompletionsStream(chatCompletionsOptions);
+        int index = 0;
+        for (ChatCompletions completion : completions) {
+
+            LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
+                    completion.getUsage().getPromptTokens(),
+                    completion.getUsage().getCompletionTokens(),
+                    completion.getUsage().getTotalTokens());
+
+            for (ChatChoice choice : completion.getChoices()) {
+                if (choice.getDelta().getContent() == null) {
+                    continue;
+                }
+
+                RAGResponse ragResponse = new RAGResponse.Builder()
+                        .question(question)
+                        .prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
+                        .answer(choice.getMessage().getContent())
+                        .sources(sources)
+                        .build();
+
+                ChatResponse response;
+                if (index == 0) {
+                    response = ChatResponse.buildChatResponse(ragResponse);
+                } else {
+                    response = ChatResponse.buildChatDeltaResponse(index, ragResponse);
+                }
+                index++;
+
+                try {
+                    String value = objectMapper.writeValueAsString(response) + "\n";
+                    outputStream.write(value.getBytes());
+                    outputStream.flush();
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+            }
+        }
+    }
 }
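
Each flushed write above is one complete JSON document followed by a newline, so a client can render partial answers as the chunks arrive. A minimal client-side sketch using java.net.http; the endpoint URL and request payload are assumptions for illustration, not part of this commit:

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.stream.Stream;

public final class StreamingClientExample {

    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();

        // Hypothetical endpoint and payload; adjust to whatever the backend controller actually exposes.
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/api/ask/stream"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString("{\"question\":\"What does my plan cover?\"}"))
                .build();

        // ofLines() hands over each newline-delimited JSON chunk as soon as it arrives.
        HttpResponse<Stream<String>> response =
                client.send(request, HttpResponse.BodyHandlers.ofLines());

        response.body().forEach(line -> System.out.println("chunk: " + line));
    }
}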

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelChainsApproach.java

Lines changed: 7 additions & 0 deletions
@@ -14,7 +14,9 @@
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Component;
+import reactor.core.publisher.Flux;

+import java.io.OutputStream;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
@@ -81,6 +83,11 @@ public RAGResponse run(String question, RAGOptions options) {

     }

+    @Override
+    public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) {
+        throw new IllegalStateException("Streaming not supported for this approach");
+    }
+
     private List<ContentSource> formSourcesList(String result) {
         if (result == null) {
             return Collections.emptyList();

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelPlannerApproach.java

Lines changed: 7 additions & 0 deletions
@@ -15,7 +15,9 @@
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Component;
+import reactor.core.publisher.Flux;

+import java.io.OutputStream;
 import java.util.Objects;
 import java.util.Set;

@@ -80,6 +82,11 @@ public RAGResponse run(String question, RAGOptions options) {

     }

+    @Override
+    public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) {
+        throw new IllegalStateException("Streaming not supported for this approach");
+    }
+
     private Kernel buildSemanticKernel( RAGOptions options) {
         Kernel kernel = SKBuilders.kernel()
                 .withDefaultAIService(SKBuilders.chatCompletion()

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelWithMemoryApproach.java

Lines changed: 35 additions & 28 deletions
@@ -19,16 +19,18 @@
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Component;
+import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;

+import java.io.OutputStream;
 import java.util.List;
 import java.util.function.Function;
 import java.util.stream.Collectors;

 /**
- * Accomplish the same task as in the PlainJavaAskApproach approach but using Semantic Kernel framework:
- * 1. Memory abstraction is used for vector search capability. It uses Azure Cognitive Search as memory store.
- * 2. Semantic function has been defined to ask question using sources from memory search results
+ * Accomplish the same task as in the PlainJavaAskApproach approach but using Semantic Kernel framework:
+ * 1. Memory abstraction is used for vector search capability. It uses Azure Cognitive Search as memory store.
+ * 2. Semantic function has been defined to ask question using sources from memory search results
  */
 @Component
 public class JavaSemanticKernelWithMemoryApproach implements RAGApproach<String, RAGResponse> {
@@ -40,8 +42,10 @@ public class JavaSemanticKernelWithMemoryApproach implements RAGApproach<String,

     private final String EMBEDDING_FIELD_NAME = "embedding";

-    @Value("${cognitive.search.service}") String searchServiceName ;
-    @Value("${cognitive.search.index}") String indexName;
+    @Value("${cognitive.search.service}")
+    String searchServiceName;
+    @Value("${cognitive.search.index}")
+    String indexName;
     @Value("${openai.chatgpt.deployment}")
     private String gptChatDeploymentModelId;

@@ -70,11 +74,11 @@ public RAGResponse run(String question, RAGOptions options) {
          * Question embeddings are provided to cognitive search via search options.
          */
         List<MemoryQueryResult> memoryResult = semanticKernel.getMemory().searchAsync(
-                indexName,
-                question,
-                options.getTop(),
-                0.5f,
-                false)
+                        indexName,
+                        question,
+                        options.getTop(),
+                        0.5f,
+                        false)
                 .block();

         LOGGER.info("Total {} sources found in cognitive vector store for search query[{}]", memoryResult.size(), question);
@@ -90,14 +94,19 @@ public RAGResponse run(String question, RAGOptions options) {
         Mono<SKContext> result = semanticKernel.getFunction("RAG", "AnswerQuestion").invokeAsync(skcontext);

         return new RAGResponse.Builder()
-                //.prompt(plan.toPlanString())
-                .prompt("placeholders for prompt")
-                .answer(result.block().getResult())
-                .sources(sourcesList)
-                .sourcesAsText(sources)
-                .question(question)
-                .build();
+                        //.prompt(plan.toPlanString())
+                        .prompt("placeholders for prompt")
+                        .answer(result.block().getResult())
+                        .sources(sourcesList)
+                        .sourcesAsText(sources)
+                        .question(question)
+                        .build();
+
+    }

+    @Override
+    public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) {
+        throw new IllegalStateException("Streaming not supported for this approach");
     }

     private List<ContentSource> buildSources(List<MemoryQueryResult> memoryResult) {
@@ -123,15 +132,14 @@ private String buildSourcesText(List<MemoryQueryResult> memoryResult) {
         return sourcesContentBuffer.toString();
     }

-    private Kernel buildSemanticKernel( RAGOptions options) {
-
+    private Kernel buildSemanticKernel(RAGOptions options) {
         var kernelWithACS = SKBuilders.kernel()
                 .withMemoryStorage(
                         new CustomAzureCognitiveSearchMemoryStore("https://%s.search.windows.net".formatted(searchServiceName),
-                        tokenCredential,
-                        this.searchAsyncClient,
-                        this.EMBEDDING_FIELD_NAME,
-                        buildCustomMemoryMapper()))
+                                tokenCredential,
+                                this.searchAsyncClient,
+                                this.EMBEDDING_FIELD_NAME,
+                                buildCustomMemoryMapper()))
                 .withDefaultAIService(SKBuilders.textEmbeddingGeneration()
                         .withOpenAIClient(openAIAsyncClient)
                         .withModelId(embeddingDeploymentModelId)
@@ -142,14 +150,13 @@ private Kernel buildSemanticKernel( RAGOptions options) {
                         .build())
                 .build();

-        kernelWithACS.importSkillFromResources("semantickernel/Plugins","RAG","AnswerQuestion",null);
-        return kernelWithACS;
+        kernelWithACS.importSkillFromResources("semantickernel/Plugins", "RAG", "AnswerQuestion", null);
+        return kernelWithACS;
     }

-
-    private Function<SearchDocument, MemoryRecord> buildCustomMemoryMapper(){
+    private Function<SearchDocument, MemoryRecord> buildCustomMemoryMapper() {
         return searchDocument -> {
-            return MemoryRecord.localRecord(
+                return MemoryRecord.localRecord(
                 (String) searchDocument.get("sourcepage"),
                 (String) searchDocument.get("content"),
                 "chunked text from original source",
