Skip to content

Commit d36cec8

Browse files
committed
Add stream to Ask
1 parent 627e685 commit d36cec8

File tree

2 files changed

+70
-4
lines changed

2 files changed

+70
-4
lines changed

app/backend/src/main/java/com/microsoft/openai/samples/rag/ask/controller/AskController.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
import org.slf4j.Logger;
1111
import org.slf4j.LoggerFactory;
1212
import org.springframework.http.HttpStatus;
13+
import org.springframework.http.MediaType;
1314
import org.springframework.http.ResponseEntity;
1415
import org.springframework.util.StringUtils;
1516
import org.springframework.web.bind.annotation.PostMapping;
1617
import org.springframework.web.bind.annotation.RequestBody;
1718
import org.springframework.web.bind.annotation.RestController;
19+
import org.springframework.web.server.ResponseStatusException;
20+
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
1821

1922
@RestController
2023
public class AskController {
@@ -26,8 +29,64 @@ public class AskController {
2629
this.ragApproachFactory = ragApproachFactory;
2730
}
2831

32+
@PostMapping(
33+
value = "/api/ask",
34+
produces = MediaType.APPLICATION_NDJSON_VALUE
35+
)
36+
public ResponseEntity openAIAskStream(
37+
@RequestBody ChatAppRequest askRequest
38+
) {
39+
if (!askRequest.stream()) {
40+
LOGGER.warn("Requested a content-type of application/ndjson however did not requested streaming. Please use a content-type of application/json");
41+
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Requested a content-type of application/ndjson however did not requested streaming. Please use a content-type of application/json");
42+
}
43+
44+
String question = askRequest.messages().get(askRequest.messages().size() - 1).content();
45+
LOGGER.info("Received request for ask api with question [{}] and approach[{}]", question, askRequest.approach());
46+
47+
if (!StringUtils.hasText(askRequest.approach())) {
48+
LOGGER.warn("approach cannot be null in ASK request");
49+
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(null);
50+
}
51+
52+
if (!StringUtils.hasText(question)) {
53+
LOGGER.warn("question cannot be null in ASK request");
54+
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(null);
55+
}
56+
57+
var ragOptions = new RAGOptions.Builder()
58+
.retrievialMode(askRequest.context().overrides().retrieval_mode().name())
59+
.semanticKernelMode(askRequest.context().overrides().semantic_kernel_mode())
60+
.semanticRanker(askRequest.context().overrides().semantic_ranker())
61+
.semanticCaptions(askRequest.context().overrides().semantic_captions())
62+
.excludeCategory(askRequest.context().overrides().exclude_category())
63+
.promptTemplate(askRequest.context().overrides().prompt_template())
64+
.top(askRequest.context().overrides().top())
65+
.build();
66+
67+
RAGApproach<String, RAGResponse> ragApproach = ragApproachFactory.createApproach(askRequest.approach(), RAGType.ASK, ragOptions);
68+
69+
StreamingResponseBody response = output -> {
70+
try {
71+
ragApproach.runStreaming(question, ragOptions, output);
72+
} finally {
73+
output.flush();
74+
output.close();
75+
}
76+
};
77+
78+
return ResponseEntity.ok()
79+
.contentType(MediaType.APPLICATION_NDJSON)
80+
.body(response);
81+
}
82+
2983
@PostMapping("/api/ask")
3084
public ResponseEntity<ChatResponse> openAIAsk(@RequestBody ChatAppRequest askRequest) {
85+
if (askRequest.stream()) {
86+
LOGGER.warn("Requested a content-type of application/json however also requested streaming. Please use a content-type of application/ndjson");
87+
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Requested a content-type of application/json however also requested streaming. Please use a content-type of application/ndjson");
88+
}
89+
3190
String question = askRequest.messages().get(askRequest.messages().size() - 1).content();
3291
LOGGER.info("Received request for ask api with question [{}] and approach[{}]", question, askRequest.approach());
3392

app/backend/src/main/java/com/microsoft/openai/samples/rag/chat/controller/ChatController.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import org.springframework.web.bind.annotation.PostMapping;
2020
import org.springframework.web.bind.annotation.RequestBody;
2121
import org.springframework.web.bind.annotation.RestController;
22+
import org.springframework.web.server.ResponseStatusException;
2223
import org.springframework.web.servlet.mvc.method.annotation.StreamingResponseBody;
23-
import reactor.core.publisher.Flux;
2424

2525
import java.util.ArrayList;
2626
import java.util.Collection;
@@ -43,6 +43,11 @@ public ChatController(RAGApproachFactory<ChatGPTConversation, RAGResponse> ragAp
4343
public ResponseEntity<StreamingResponseBody> openAIAskStream(
4444
@RequestBody ChatAppRequest chatRequest
4545
) {
46+
if (!chatRequest.stream()) {
47+
LOGGER.warn("Requested a content-type of application/ndjson however did not requested streaming. Please use a content-type of application/json");
48+
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Requested a content-type of application/ndjson however did not requested streaming. Please use a content-type of application/json");
49+
}
50+
4651
LOGGER.info("Received request for chat api with approach[{}]", chatRequest.approach());
4752

4853
if (!StringUtils.hasText(chatRequest.approach())) {
@@ -69,9 +74,6 @@ public ResponseEntity<StreamingResponseBody> openAIAskStream(
6974

7075
ChatGPTConversation chatGPTConversation = convertToChatGPT(chatRequest.messages());
7176

72-
73-
Flux<Integer> counter = Flux.range(0, Integer.MAX_VALUE);
74-
7577
StreamingResponseBody response = output -> {
7678
try {
7779
ragApproach.runStreaming(chatGPTConversation, ragOptions, output);
@@ -91,6 +93,11 @@ public ResponseEntity<StreamingResponseBody> openAIAskStream(
9193
produces = MediaType.APPLICATION_JSON_VALUE
9294
)
9395
public ResponseEntity<ChatResponse> openAIAsk(@RequestBody ChatAppRequest chatRequest) {
96+
if (chatRequest.stream()) {
97+
LOGGER.warn("Requested a content-type of application/json however also requested streaming. Please use a content-type of application/ndjson");
98+
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "Requested a content-type of application/json however also requested streaming. Please use a content-type of application/ndjson");
99+
}
100+
94101
LOGGER.info("Received request for chat api with approach[{}]", chatRequest.approach());
95102

96103
if (!StringUtils.hasText(chatRequest.approach())) {

0 commit comments

Comments
 (0)