1
1
package com .microsoft .openai .samples .rag .ask .approaches ;
2
2
3
+ import com .azure .ai .openai .models .ChatChoice ;
3
4
import com .azure .ai .openai .models .ChatCompletions ;
5
+ import com .azure .core .util .IterableStream ;
6
+ import com .fasterxml .jackson .databind .ObjectMapper ;
4
7
import com .microsoft .openai .samples .rag .approaches .ContentSource ;
5
8
import com .microsoft .openai .samples .rag .approaches .RAGApproach ;
6
9
import com .microsoft .openai .samples .rag .approaches .RAGOptions ;
7
10
import com .microsoft .openai .samples .rag .approaches .RAGResponse ;
8
11
import com .microsoft .openai .samples .rag .common .ChatGPTUtils ;
12
+ import com .microsoft .openai .samples .rag .controller .ChatResponse ;
9
13
import com .microsoft .openai .samples .rag .proxy .OpenAIProxy ;
10
14
import com .microsoft .openai .samples .rag .retrieval .FactsRetrieverProvider ;
11
15
import com .microsoft .openai .samples .rag .retrieval .Retriever ;
12
16
import org .slf4j .Logger ;
13
17
import org .slf4j .LoggerFactory ;
14
18
import org .springframework .stereotype .Component ;
15
- import reactor .core .publisher .Flux ;
16
19
20
+ import java .io .IOException ;
17
21
import java .io .OutputStream ;
18
22
import java .util .List ;
19
23
@@ -28,10 +32,12 @@ public class PlainJavaAskApproach implements RAGApproach<String, RAGResponse> {
28
32
private static final Logger LOGGER = LoggerFactory .getLogger (PlainJavaAskApproach .class );
29
33
private final OpenAIProxy openAIProxy ;
30
34
private final FactsRetrieverProvider factsRetrieverProvider ;
35
+ private final ObjectMapper objectMapper ;
31
36
32
- public PlainJavaAskApproach (FactsRetrieverProvider factsRetrieverProvider , OpenAIProxy openAIProxy ) {
37
/**
 * Creates the approach and stores the collaborators it is given.
 *
 * @param factsRetrieverProvider provider used to obtain the facts retriever
 * @param openAIProxy            proxy used to call the chat completions API
 * @param objectMapper           mapper used to serialize streamed responses
 */
public PlainJavaAskApproach(
        FactsRetrieverProvider factsRetrieverProvider,
        OpenAIProxy openAIProxy,
        ObjectMapper objectMapper) {
    this.objectMapper = objectMapper;
    this.openAIProxy = openAIProxy;
    this.factsRetrieverProvider = factsRetrieverProvider;
}
36
42
37
43
/**
@@ -78,13 +84,7 @@ public RAGResponse run(String question, RAGOptions options) {
78
84
}
79
85
80
86
@ Override
81
- public void runStreaming (String questionOrConversation , RAGOptions options , OutputStream outputStream ) {
82
- throw new UnsupportedOperationException ("Streaming not supported for PlainJavaAskApproach" );
83
- }
84
- /*
85
- @Override
86
- public Flux<RAGResponse> runStreaming(String question, RAGOptions options) {
87
-
87
+ public void runStreaming (String question , RAGOptions options , OutputStream outputStream ) {
88
88
//Get instance of retriever based on the retrieval mode: hybrid, text, vectors.
89
89
Retriever factsRetriever = factsRetrieverProvider .getFactsRetriever (options );
90
90
List <ContentSource > sources = factsRetriever .retrieveFromQuestion (question , options );
@@ -105,23 +105,43 @@ public Flux<RAGResponse> runStreaming(String question, RAGOptions options) {
105
105
var groundedChatMessages = answerQuestionChatTemplate .getMessages (question , sources );
106
106
var chatCompletionsOptions = ChatGPTUtils .buildDefaultChatCompletionsOptions (groundedChatMessages );
107
107
108
- Flux<ChatCompletions> completions = Flux.fromIterable(openAIProxy.getChatCompletionsStream(chatCompletionsOptions));
109
- return completions
110
- .flatMap(completion -> {
111
- LOGGER.info("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]",
112
- completion.getUsage().getPromptTokens(),
113
- completion.getUsage().getCompletionTokens(),
114
- completion.getUsage().getTotalTokens());
115
-
116
- return Flux.fromIterable(completion.getChoices())
117
- .map(choice -> new RAGResponse.Builder()
118
- .question(question)
119
- .prompt(ChatGPTUtils.formatAsChatML(groundedChatMessages))
120
- .answer(choice.getMessage().getContent())
121
- .sources(sources)
122
- .build());
123
- });
108
+ IterableStream <ChatCompletions > completions = openAIProxy .getChatCompletionsStream (chatCompletionsOptions );
109
+ int index = 0 ;
110
+ for (ChatCompletions completion : completions ) {
111
+
112
+ LOGGER .info ("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]" ,
113
+ completion .getUsage ().getPromptTokens (),
114
+ completion .getUsage ().getCompletionTokens (),
115
+ completion .getUsage ().getTotalTokens ());
116
+
117
+ for (ChatChoice choice : completion .getChoices ()) {
118
+ if (choice .getDelta ().getContent () == null ) {
119
+ continue ;
120
+ }
121
+
122
+ RAGResponse ragResponse = new RAGResponse .Builder ()
123
+ .question (question )
124
+ .prompt (ChatGPTUtils .formatAsChatML (groundedChatMessages ))
125
+ .answer (choice .getMessage ().getContent ())
126
+ .sources (sources )
127
+ .build ();
128
+
129
+ ChatResponse response ;
130
+ if (index == 0 ) {
131
+ response = ChatResponse .buildChatResponse (ragResponse );
132
+ } else {
133
+ response = ChatResponse .buildChatDeltaResponse (index , ragResponse );
134
+ }
135
+ index ++;
136
+
137
+ try {
138
+ String value = objectMapper .writeValueAsString (response ) + "\n " ;
139
+ outputStream .write (value .getBytes ());
140
+ outputStream .flush ();
141
+ } catch (IOException e ) {
142
+ throw new RuntimeException (e );
143
+ }
144
+ }
145
+ }
124
146
}
125
-
126
- */
127
147
}
0 commit comments