1
1
package com .microsoft .openai .samples .rag .ask .approaches ;
2
2
3
+ import com .azure .ai .openai .models .ChatChoice ;
3
4
import com .azure .ai .openai .models .ChatCompletions ;
5
+ import com .azure .core .util .IterableStream ;
6
+ import com .fasterxml .jackson .databind .ObjectMapper ;
4
7
import com .microsoft .openai .samples .rag .approaches .ContentSource ;
5
8
import com .microsoft .openai .samples .rag .approaches .RAGApproach ;
6
9
import com .microsoft .openai .samples .rag .approaches .RAGOptions ;
7
10
import com .microsoft .openai .samples .rag .approaches .RAGResponse ;
8
11
import com .microsoft .openai .samples .rag .common .ChatGPTUtils ;
9
- import com .microsoft .openai .samples .rag .retrieval . FactsRetrieverProvider ;
12
+ import com .microsoft .openai .samples .rag .controller . ChatResponse ;
10
13
import com .microsoft .openai .samples .rag .proxy .OpenAIProxy ;
14
+ import com .microsoft .openai .samples .rag .retrieval .FactsRetrieverProvider ;
11
15
import com .microsoft .openai .samples .rag .retrieval .Retriever ;
12
16
import org .slf4j .Logger ;
13
17
import org .slf4j .LoggerFactory ;
14
18
import org .springframework .stereotype .Component ;
15
19
20
+ import java .io .IOException ;
21
+ import java .io .OutputStream ;
16
22
import java .util .List ;
17
23
18
24
/**
19
25
* Simple retrieve-then-read java implementation, using the Cognitive Search and OpenAI APIs directly. It first retrieves
20
- * top documents from search, then constructs a prompt with them, and then uses OpenAI to generate a completion
21
- * (answer) with that prompt.
26
+ * top documents from search, then constructs a prompt with them, and then uses OpenAI to generate a completion
27
+ * (answer) with that prompt.
22
28
*/
23
29
@ Component
24
30
public class PlainJavaAskApproach implements RAGApproach <String , RAGResponse > {
25
31
26
32
private static final Logger LOGGER = LoggerFactory .getLogger (PlainJavaAskApproach .class );
27
33
private final OpenAIProxy openAIProxy ;
28
34
private final FactsRetrieverProvider factsRetrieverProvider ;
35
+ private final ObjectMapper objectMapper ;
29
36
30
- public PlainJavaAskApproach (FactsRetrieverProvider factsRetrieverProvider , OpenAIProxy openAIProxy ) {
37
+ public PlainJavaAskApproach (FactsRetrieverProvider factsRetrieverProvider , OpenAIProxy openAIProxy , ObjectMapper objectMapper ) {
31
38
this .factsRetrieverProvider = factsRetrieverProvider ;
32
39
this .openAIProxy = openAIProxy ;
40
+ this .objectMapper = objectMapper ;
33
41
}
34
42
35
43
/**
@@ -39,8 +47,6 @@ public PlainJavaAskApproach(FactsRetrieverProvider factsRetrieverProvider, OpenA
39
47
*/
40
48
@ Override
41
49
public RAGResponse run (String question , RAGOptions options ) {
42
- //TODO exception handling
43
-
44
50
//Get instance of retriever based on the retrieval mode: hybryd, text, vectors.
45
51
Retriever factsRetriever = factsRetrieverProvider .getFactsRetriever (options );
46
52
List <ContentSource > sources = factsRetriever .retrieveFromQuestion (question , options );
@@ -51,14 +57,14 @@ public RAGResponse run(String question, RAGOptions options) {
51
57
var customPromptEmpty = (customPrompt == null ) || (customPrompt != null && customPrompt .isEmpty ());
52
58
53
59
//true will replace the default prompt. False will add custom prompt as suffix to the default prompt
54
- var replacePrompt = !customPromptEmpty && !customPrompt .startsWith ("|" );
55
- if (!replacePrompt && !customPromptEmpty ){
60
+ var replacePrompt = !customPromptEmpty && !customPrompt .startsWith ("|" );
61
+ if (!replacePrompt && !customPromptEmpty ) {
56
62
customPrompt = customPrompt .substring (1 );
57
63
}
58
64
59
65
var answerQuestionChatTemplate = new AnswerQuestionChatTemplate (customPrompt , replacePrompt );
60
66
61
- var groundedChatMessages = answerQuestionChatTemplate .getMessages (question ,sources );
67
+ var groundedChatMessages = answerQuestionChatTemplate .getMessages (question , sources );
62
68
var chatCompletionsOptions = ChatGPTUtils .buildDefaultChatCompletionsOptions (groundedChatMessages );
63
69
64
70
// STEP 3: Generate a contextual and content specific answer using the retrieve facts
@@ -75,8 +81,67 @@ public RAGResponse run(String question, RAGOptions options) {
75
81
.answer (chatCompletions .getChoices ().get (0 ).getMessage ().getContent ())
76
82
.sources (sources )
77
83
.build ();
78
-
79
84
}
80
85
86
+ @ Override
87
+ public void runStreaming (String question , RAGOptions options , OutputStream outputStream ) {
88
+ //Get instance of retriever based on the retrieval mode: hybryd, text, vectors.
89
+ Retriever factsRetriever = factsRetrieverProvider .getFactsRetriever (options );
90
+ List <ContentSource > sources = factsRetriever .retrieveFromQuestion (question , options );
91
+ LOGGER .info ("Total {} sources found in cognitive search for keyword search query[{}]" , sources .size (),
92
+ question );
81
93
94
+ var customPrompt = options .getPromptTemplate ();
95
+ var customPromptEmpty = (customPrompt == null ) || (customPrompt != null && customPrompt .isEmpty ());
96
+
97
+ //true will replace the default prompt. False will add custom prompt as suffix to the default prompt
98
+ var replacePrompt = !customPromptEmpty && !customPrompt .startsWith ("|" );
99
+ if (!replacePrompt && !customPromptEmpty ) {
100
+ customPrompt = customPrompt .substring (1 );
101
+ }
102
+
103
+ var answerQuestionChatTemplate = new AnswerQuestionChatTemplate (customPrompt , replacePrompt );
104
+
105
+ var groundedChatMessages = answerQuestionChatTemplate .getMessages (question , sources );
106
+ var chatCompletionsOptions = ChatGPTUtils .buildDefaultChatCompletionsOptions (groundedChatMessages );
107
+
108
+ IterableStream <ChatCompletions > completions = openAIProxy .getChatCompletionsStream (chatCompletionsOptions );
109
+ int index = 0 ;
110
+ for (ChatCompletions completion : completions ) {
111
+
112
+ LOGGER .info ("Chat completion generated with Prompt Tokens[{}], Completions Tokens[{}], Total Tokens[{}]" ,
113
+ completion .getUsage ().getPromptTokens (),
114
+ completion .getUsage ().getCompletionTokens (),
115
+ completion .getUsage ().getTotalTokens ());
116
+
117
+ for (ChatChoice choice : completion .getChoices ()) {
118
+ if (choice .getDelta ().getContent () == null ) {
119
+ continue ;
120
+ }
121
+
122
+ RAGResponse ragResponse = new RAGResponse .Builder ()
123
+ .question (question )
124
+ .prompt (ChatGPTUtils .formatAsChatML (groundedChatMessages ))
125
+ .answer (choice .getMessage ().getContent ())
126
+ .sources (sources )
127
+ .build ();
128
+
129
+ ChatResponse response ;
130
+ if (index == 0 ) {
131
+ response = ChatResponse .buildChatResponse (ragResponse );
132
+ } else {
133
+ response = ChatResponse .buildChatDeltaResponse (index , ragResponse );
134
+ }
135
+ index ++;
136
+
137
+ try {
138
+ String value = objectMapper .writeValueAsString (response ) + "\n " ;
139
+ outputStream .write (value .getBytes ());
140
+ outputStream .flush ();
141
+ } catch (IOException e ) {
142
+ throw new RuntimeException (e );
143
+ }
144
+ }
145
+ }
146
+ }
82
147
}
0 commit comments