1
1
package com .microsoft .openai .samples .rag .ask .controller ;
2
2
3
- import com .microsoft .openai .samples .rag .approaches .*;
4
- import com .microsoft .openai .samples .rag .controller .Overrides ;
3
+ import com .microsoft .openai .samples .rag .approaches .RAGApproach ;
4
+ import com .microsoft .openai .samples .rag .approaches .RAGApproachFactory ;
5
+ import com .microsoft .openai .samples .rag .approaches .RAGOptions ;
6
+ import com .microsoft .openai .samples .rag .approaches .RAGResponse ;
7
+ import com .microsoft .openai .samples .rag .approaches .RAGType ;
8
+ import com .microsoft .openai .samples .rag .controller .ChatAppRequest ;
9
+ import com .microsoft .openai .samples .rag .controller .ChatResponse ;
10
+ import com .microsoft .openai .samples .rag .controller .ResponseChoice ;
11
+ import com .microsoft .openai .samples .rag .controller .ResponseContext ;
12
+ import com .microsoft .openai .samples .rag .controller .ResponseMessage ;
13
+ import com .microsoft .openai .samples .rag .common .ChatGPTMessage ;
5
14
import org .slf4j .Logger ;
6
15
import org .slf4j .LoggerFactory ;
7
16
import org .springframework .http .HttpStatus ;
11
20
import org .springframework .web .bind .annotation .RequestBody ;
12
21
import org .springframework .web .bind .annotation .RestController ;
13
22
14
- import java .util .Arrays ;
23
+ import java .util .Collections ;
15
24
import java .util .List ;
16
25
17
26
@ RestController
@@ -25,59 +34,34 @@ public class AskController {
25
34
}
26
35
27
36
@ PostMapping ("/api/ask" )
28
- public ResponseEntity <AskResponse > openAIAsk (@ RequestBody AskRequest askRequest ) {
29
- LOGGER .info ("Received request for ask api with question [{}] and approach[{}]" , askRequest .getQuestion (), askRequest .getApproach ());
37
+ public ResponseEntity <ChatResponse > openAIAsk (@ RequestBody ChatAppRequest askRequest ) {
38
+ String question = askRequest .messages ().get (askRequest .messages ().size () - 1 ).content ();
39
+ LOGGER .info ("Received request for ask api with question [{}] and approach[{}]" , question , askRequest .approach ());
30
40
31
- if (!StringUtils .hasText (askRequest .getApproach ())) {
41
+ if (!StringUtils .hasText (askRequest .approach ())) {
32
42
LOGGER .warn ("approach cannot be null in ASK request" );
33
43
return ResponseEntity .status (HttpStatus .BAD_REQUEST ).body (null );
34
44
}
35
45
36
- if (!StringUtils .hasText (askRequest . getQuestion () )) {
46
+ if (!StringUtils .hasText (question )) {
37
47
LOGGER .warn ("question cannot be null in ASK request" );
38
48
return ResponseEntity .status (HttpStatus .BAD_REQUEST ).body (null );
39
49
}
40
50
41
51
var ragOptions = new RAGOptions .Builder ()
42
- .retrievialMode (askRequest .getOverrides ().getRetrievalMode ())
43
- .semanticKernelMode (askRequest .getOverrides ().getSemantickKernelMode ())
44
- .semanticRanker (askRequest .getOverrides ().isSemanticRanker ())
45
- .semanticCaptions (askRequest .getOverrides ().isSemanticCaptions ())
46
- .excludeCategory (askRequest .getOverrides ().getExcludeCategory ())
47
- .promptTemplate (askRequest .getOverrides ().getPromptTemplate ())
48
- .top (askRequest .getOverrides ().getTop ())
52
+ .retrievialMode (askRequest .context ().overrides (). retrieval_mode (). name ())
53
+ .semanticKernelMode (askRequest .context ().overrides (). semantic_kernel_mode ())
54
+ .semanticRanker (askRequest .context ().overrides (). semantic_ranker ())
55
+ .semanticCaptions (askRequest .context ().overrides (). semantic_captions ())
56
+ .excludeCategory (askRequest .context ().overrides (). exclude_category ())
57
+ .promptTemplate (askRequest .context ().overrides (). prompt_template ())
58
+ .top (askRequest .context ().overrides (). top ())
49
59
.build ();
50
60
51
- RAGApproach <String , RAGResponse > ragApproach = ragApproachFactory .createApproach (askRequest .getApproach (), RAGType .ASK , ragOptions );
61
+ RAGApproach <String , RAGResponse > ragApproach = ragApproachFactory .createApproach (askRequest .approach (), RAGType .ASK , ragOptions );
52
62
53
- //set empty overrides if not provided
54
- if (askRequest .getOverrides () == null ) {
55
- askRequest .setOverrides (new Overrides ());
56
- }
57
-
58
-
59
-
60
- return ResponseEntity .ok (buildAskResponse (ragApproach .run (askRequest .getQuestion (), ragOptions )));
63
+ return ResponseEntity .ok (ChatResponse .buildChatResponse (ragApproach .run (question , ragOptions )));
61
64
}
62
65
63
- private AskResponse buildAskResponse (RAGResponse ragResponse ) {
64
- var askResponse = new AskResponse ();
65
-
66
- askResponse .setAnswer (ragResponse .getAnswer ());
67
- List <String > dataPoints ;
68
- if (ragResponse .getSourcesAsText () != null && !ragResponse .getSourcesAsText ().isEmpty ()) {
69
- dataPoints = Arrays .asList (ragResponse .getSourcesAsText ().split ("\n " ));
70
- } else {
71
- dataPoints = ragResponse .getSources ().stream ()
72
- .map (source -> source .getSourceName () + ": " + source .getSourceContent ())
73
- .toList ();
74
- }
75
-
76
- askResponse .setDataPoints (dataPoints );
77
-
78
- askResponse .setThoughts ("Question:<br>" + ragResponse .getQuestion () + "<br><br>Prompt:<br>" + ragResponse .getPrompt ().replace ("\n " , "<br>" ));
79
-
80
- return askResponse ;
81
- }
82
66
83
67
}
0 commit comments