Skip to content

Commit 6aac2a7

Browse files
committed
Convert to non-reactive
1 parent 4da1711 commit 6aac2a7

File tree

8 files changed

+107
-53
lines changed

8 files changed

+107
-53
lines changed

app/backend/pom.xml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@
5555
<artifactId>spring-boot-starter-web</artifactId>
5656
</dependency>
5757

58+
<dependency>
59+
<groupId>com.fasterxml.jackson.core</groupId>
60+
<artifactId>jackson-databind</artifactId>
61+
</dependency>
62+
<dependency>
63+
<groupId>com.fasterxml.jackson.core</groupId>
64+
<artifactId>jackson-core</artifactId>
65+
</dependency>
66+
5867
<dependency>
5968
<groupId>org.springframework.boot</groupId>
6069
<artifactId>spring-boot-starter-test</artifactId>
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package com.microsoft.openai.samples.rag.approaches;
22

3+
import com.microsoft.openai.samples.rag.common.ChatGPTConversation;
34
import reactor.core.publisher.Flux;
45

6+
import java.io.OutputStream;
7+
58
public interface RAGApproach<I, O> {
69

710
O run(I questionOrConversation, RAGOptions options);
8-
Flux<O> runStreaming(I questionOrConversation, RAGOptions options);
11+
void runStreaming(I questionOrConversation, RAGOptions options, OutputStream outputStream);
912
}

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/PlainJavaAskApproach.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.springframework.stereotype.Component;
1515
import reactor.core.publisher.Flux;
1616

17+
import java.io.OutputStream;
1718
import java.util.List;
1819

1920
/**
@@ -76,6 +77,11 @@ public RAGResponse run(String question, RAGOptions options) {
7677
.build();
7778
}
7879

80+
@Override
81+
public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) {
82+
throw new UnsupportedOperationException("Streaming not supported for PlainJavaAskApproach");
83+
}
84+
/*
7985
@Override
8086
public Flux<RAGResponse> runStreaming(String question, RAGOptions options) {
8187
@@ -116,4 +122,6 @@ public Flux<RAGResponse> runStreaming(String question, RAGOptions options) {
116122
.build());
117123
});
118124
}
125+
126+
*/
119127
}

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelChainsApproach.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.springframework.stereotype.Component;
1717
import reactor.core.publisher.Flux;
1818

19+
import java.io.OutputStream;
1920
import java.util.Arrays;
2021
import java.util.Collections;
2122
import java.util.List;
@@ -83,8 +84,8 @@ public RAGResponse run(String question, RAGOptions options) {
8384
}
8485

8586
@Override
86-
public Flux<RAGResponse> runStreaming(String questionOrConversation, RAGOptions options) {
87-
return Flux.error(new IllegalStateException("Streaming not supported for this approach"));
87+
public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) {
88+
throw new IllegalStateException("Streaming not supported for this approach");
8889
}
8990

9091
private List<ContentSource> formSourcesList(String result) {

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelPlannerApproach.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.springframework.stereotype.Component;
1818
import reactor.core.publisher.Flux;
1919

20+
import java.io.OutputStream;
2021
import java.util.Objects;
2122
import java.util.Set;
2223

@@ -82,8 +83,8 @@ public RAGResponse run(String question, RAGOptions options) {
8283
}
8384

8485
@Override
85-
public Flux<RAGResponse> runStreaming(String questionOrConversation, RAGOptions options) {
86-
return Flux.error(new IllegalStateException("Streaming not supported for this approach"));
86+
public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) {
87+
throw new IllegalStateException("Streaming not supported for this approach");
8788
}
8889

8990
private Kernel buildSemanticKernel( RAGOptions options) {

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/approaches/semantickernel/JavaSemanticKernelWithMemoryApproach.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import reactor.core.publisher.Flux;
2323
import reactor.core.publisher.Mono;
2424

25+
import java.io.OutputStream;
2526
import java.util.List;
2627
import java.util.function.Function;
2728
import java.util.stream.Collectors;
@@ -104,8 +105,8 @@ public RAGResponse run(String question, RAGOptions options) {
104105
}
105106

106107
@Override
107-
public Flux<RAGResponse> runStreaming(String questionOrConversation, RAGOptions options) {
108-
return Flux.error(new IllegalStateException("Streaming not supported for this approach"));
108+
public void runStreaming(String questionOrConversation, RAGOptions options, OutputStream outputStream) {
109+
throw new IllegalStateException("Streaming not supported for this approach");
109110
}
110111

111112
private List<ContentSource> buildSources(List<MemoryQueryResult> memoryResult) {

app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/approaches/PlainJavaChatApproach.java

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
package com.microsoft.openai.samples.rag.chat.approaches;
22

3+
import com.azure.ai.openai.models.ChatChoice;
34
import com.azure.ai.openai.models.ChatCompletions;
5+
import com.azure.core.util.IterableStream;
6+
import com.fasterxml.jackson.databind.ObjectMapper;
47
import com.microsoft.openai.samples.rag.approaches.ContentSource;
58
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
69
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
710
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
811
import com.microsoft.openai.samples.rag.common.ChatGPTConversation;
912
import com.microsoft.openai.samples.rag.common.ChatGPTUtils;
13+
import com.microsoft.openai.samples.rag.controller.ChatResponse;
1014
import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
1115
import com.microsoft.openai.samples.rag.retrieval.FactsRetrieverProvider;
1216
import com.microsoft.openai.samples.rag.retrieval.Retriever;
1317
import org.slf4j.Logger;
1418
import org.slf4j.LoggerFactory;
1519
import org.springframework.context.ApplicationContext;
1620
import org.springframework.stereotype.Component;
17-
import reactor.core.publisher.Flux;
1821

22+
import java.io.IOException;
23+
import java.io.OutputStream;
1924
import java.util.List;
25+
import java.util.concurrent.atomic.AtomicInteger;
2026

2127
/**
2228
* Simple chat-read-retrieve-read java implementation, using the Cognitive Search and OpenAI APIs directly.
@@ -28,13 +34,18 @@
2834
public class PlainJavaChatApproach implements RAGApproach<ChatGPTConversation, RAGResponse> {
2935

3036
private static final Logger LOGGER = LoggerFactory.getLogger(PlainJavaChatApproach.class);
37+
private final ObjectMapper objectMapper;
3138
private ApplicationContext applicationContext;
3239
private final OpenAIProxy openAIProxy;
3340
private final FactsRetrieverProvider factsRetrieverProvider;
3441

35-
public PlainJavaChatApproach(FactsRetrieverProvider factsRetrieverProvider, OpenAIProxy openAIProxy) {
42+
public PlainJavaChatApproach(
43+
FactsRetrieverProvider factsRetrieverProvider,
44+
OpenAIProxy openAIProxy,
45+
ObjectMapper objectMapper) {
3646
this.factsRetrieverProvider = factsRetrieverProvider;
3747
this.openAIProxy = openAIProxy;
48+
this.objectMapper = objectMapper;
3849
}
3950

4051
/**
@@ -71,43 +82,60 @@ public RAGResponse run(ChatGPTConversation questionOrConversation, RAGOptions op
7182
}
7283

7384
@Override
74-
public Flux<RAGResponse> runStreaming(ChatGPTConversation questionOrConversation, RAGOptions options) {
75-
85+
public void runStreaming(
86+
ChatGPTConversation questionOrConversation,
87+
RAGOptions options,
88+
OutputStream outputStream) {
7689
Retriever factsRetriever = factsRetrieverProvider.getFactsRetriever(options);
7790
List<ContentSource> sources = factsRetriever.retrieveFromConversation(questionOrConversation, options);
7891
LOGGER.info("Total {} sources retrieved", sources.size());
7992

80-
8193
// Replace whole prompt is not supported yet
8294
var semanticSearchChat = new SemanticSearchChat(questionOrConversation, sources, options.getPromptTemplate(), false, options.isSuggestFollowupQuestions());
8395
var chatCompletionsOptions = ChatGPTUtils.buildDefaultChatCompletionsOptions(semanticSearchChat.getMessages());
8496

85-
// STEP 3: Generate a contextual and content specific answer using the search results and chat history
86-
Flux<ChatCompletions> chatCompletions = Flux.fromIterable(openAIProxy.getChatCompletionsStream(chatCompletionsOptions));
87-
88-
return chatCompletions
89-
.flatMap(completion -> {
90-
if (completion.getUsage() != null) {
91-
LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
92-
completion.getUsage().getPromptTokens(),
93-
completion.getUsage().getCompletionTokens(),
94-
completion.getUsage().getTotalTokens());
95-
}
96-
97-
return Flux.fromIterable(completion.getChoices())
98-
.filter(chatChoice -> chatChoice.getDelta().getContent() != null)
99-
.map(choice -> {
100-
return new RAGResponse.Builder()
101-
.question(ChatGPTUtils.getLastUserQuestion(questionOrConversation.getMessages()))
102-
.prompt(ChatGPTUtils.formatAsChatML(semanticSearchChat.getMessages()))
103-
.answer(choice.getDelta().getContent())
104-
.sources(sources)
105-
.build();
106-
});
107-
});
108-
109-
97+
AtomicInteger counter = new AtomicInteger(0);
98+
99+
IterableStream<ChatCompletions> completions = openAIProxy.getChatCompletionsStream(chatCompletionsOptions);
100+
101+
for (ChatCompletions completion : completions) {
102+
if (completion.getUsage() != null) {
103+
LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
104+
completion.getUsage().getPromptTokens(),
105+
completion.getUsage().getCompletionTokens(),
106+
completion.getUsage().getTotalTokens());
107+
}
108+
109+
List<ChatChoice> choices = completion.getChoices();
110+
111+
for (ChatChoice choice : choices) {
112+
if (choice.getDelta().getContent() == null) {
113+
continue;
114+
}
115+
116+
RAGResponse ragResponse = new RAGResponse.Builder()
117+
.question(ChatGPTUtils.getLastUserQuestion(questionOrConversation.getMessages()))
118+
.prompt(ChatGPTUtils.formatAsChatML(semanticSearchChat.getMessages()))
119+
.answer(choice.getDelta().getContent())
120+
.sources(sources)
121+
.build();
122+
123+
int index = counter.getAndIncrement();
124+
ChatResponse response;
125+
if (index == 0) {
126+
response = ChatResponse.buildChatResponse(ragResponse);
127+
} else {
128+
response = ChatResponse.buildChatDeltaResponse(index, ragResponse);
129+
}
130+
131+
try {
132+
String value = objectMapper.writeValueAsString(response) + "\n";
133+
outputStream.write(value.getBytes());
134+
outputStream.flush();
135+
} catch (IOException e) {
136+
throw new RuntimeException(e);
137+
}
138+
}
139+
}
110140
}
111-
112-
113141
}

app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/controller/ChatController.java

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.springframework.web.bind.annotation.PostMapping;
2020
import org.springframework.web.bind.annotation.RequestBody;
2121
import org.springframework.web.bind.annotation.RestController;
22+
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
2223
import reactor.core.publisher.Flux;
2324

2425
import java.util.ArrayList;
@@ -39,19 +40,19 @@ public ChatController(RAGApproachFactory<ChatGPTConversation, RAGResponse> ragAp
3940
value = "/api/chat",
4041
produces = MediaType.APPLICATION_NDJSON_VALUE
4142
)
42-
public Flux<ChatResponse> openAIAskStream(
43+
public ResponseEntity<StreamingResponseBody> openAIAskStream(
4344
@RequestBody ChatAppRequest chatRequest
4445
) {
4546
LOGGER.info("Received request for chat api with approach[{}]", chatRequest.approach());
4647

4748
if (!StringUtils.hasText(chatRequest.approach())) {
4849
LOGGER.warn("approach cannot be null in CHAT request");
49-
return Flux.error(new IllegalArgumentException("approach cannot be null in CHAT request"));
50+
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(null);
5051
}
5152

5253
if (chatRequest.messages() == null || chatRequest.messages().isEmpty()) {
5354
LOGGER.warn("history cannot be null in Chat request");
54-
return Flux.error(new IllegalArgumentException("history cannot be null in Chat request"));
55+
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(null);
5556
}
5657

5758
var ragOptions = new RAGOptions.Builder()
@@ -66,21 +67,23 @@ public Flux<ChatResponse> openAIAskStream(
6667

6768
RAGApproach<ChatGPTConversation, RAGResponse> ragApproach = ragApproachFactory.createApproach(chatRequest.approach(), RAGType.CHAT, ragOptions);
6869

69-
7070
ChatGPTConversation chatGPTConversation = convertToChatGPT(chatRequest.messages());
7171

72+
7273
Flux<Integer> counter = Flux.range(0, Integer.MAX_VALUE);
7374

74-
return ragApproach
75-
.runStreaming(chatGPTConversation, ragOptions)
76-
.zipWith(counter)
77-
.map((data) -> {
78-
if (data.getT2() == 0) {
79-
return ChatResponse.buildChatResponse(data.getT1());
80-
} else {
81-
return ChatResponse.buildChatDeltaResponse(data.getT2(), data.getT1());
82-
}
83-
});
75+
StreamingResponseBody response = output -> {
76+
try {
77+
ragApproach.runStreaming(chatGPTConversation, ragOptions, output);
78+
} finally {
79+
output.flush();
80+
output.close();
81+
}
82+
};
83+
84+
return ResponseEntity.ok()
85+
.contentType(MediaType.APPLICATION_NDJSON)
86+
.body(response);
8487
}
8588

8689
@PostMapping(

0 commit comments

Comments
 (0)