import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient;
import com.oracle.bmc.generativeaiinference.model.*;
+ import com.oracle.bmc.generativeaiinference.requests.ChatRequest;
import com.oracle.bmc.generativeaiinference.requests.GenerateTextRequest;
import com.oracle.bmc.generativeaiinference.requests.SummarizeTextRequest;
+ import com.oracle.bmc.generativeaiinference.responses.ChatResponse;
import com.oracle.bmc.generativeaiinference.responses.GenerateTextResponse;
import com.oracle.bmc.generativeaiinference.responses.SummarizeTextResponse;
+ import com.oracle.bmc.http.client.jersey.WrappedResponseInputStream;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

+ import java.io.*;
+ import java.nio.charset.StandardCharsets;
+ import java.util.List;
import java.util.stream.Collectors;

@Service
@@ -18,52 +25,47 @@ public class OCIGenAIService {
    private String COMPARTMENT_ID;

    @Autowired
-   private GenerativeAiInferenceClientService generativeAiInferenceClientService;
+   private GenAiInferenceClientService generativeAiInferenceClientService;

    public String resolvePrompt(String input, String modelId, boolean finetune) {
-       // Build generate text request, send, and get response
-       CohereLlmInferenceRequest llmInferenceRequest = CohereLlmInferenceRequest.builder()
-               .prompt(input)
-               .maxTokens(600)
-               .temperature((double) 1)
-               .frequencyPenalty((double) 0)
-               .topP((double) 0.75)
-               .isStream(false)
-               .isEcho(false)
-               .build();
+       // Build the Cohere chat request (replaces the old CohereLlmInferenceRequest)
+       CohereChatRequest cohereChatRequest = CohereChatRequest.builder()
+               .message(input)
+               .maxTokens(600)
+               .temperature((double) 1)
+               .frequencyPenalty((double) 0)
+               .topP((double) 0.75)
+               .topK(0)
+               .isStream(false) // TODO websockets and streams
+               .build();

-       GenerateTextDetails generateTextDetails = GenerateTextDetails.builder()
-               .servingMode(finetune ? DedicatedServingMode.builder().endpointId(modelId).build()
-                       : OnDemandServingMode.builder().modelId(modelId).build())
-               .compartmentId(COMPARTMENT_ID)
-               .inferenceRequest(llmInferenceRequest)
-               .build();
-       GenerateTextRequest generateTextRequest = GenerateTextRequest.builder()
-               .generateTextDetails(generateTextDetails)
-               .build();
-       GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
-       GenerateTextResponse generateTextResponse = client.generateText(generateTextRequest);
-       CohereLlmInferenceResponse response = (CohereLlmInferenceResponse) generateTextResponse
-               .getGenerateTextResult().getInferenceResponse();
-       String responseTexts = response.getGeneratedTexts()
-               .stream()
-               .map(t -> t.getText())
-               .collect(Collectors.joining(","));
-       return responseTexts;
+       // Wrap the request in ChatDetails; keep the dedicated endpoint for fine-tuned models
+       ChatDetails chatDetails = ChatDetails.builder()
+               .servingMode(finetune ? DedicatedServingMode.builder().endpointId(modelId).build()
+                       : OnDemandServingMode.builder().modelId(modelId).build())
+               .compartmentId(COMPARTMENT_ID)
+               .chatRequest(cohereChatRequest)
+               .build();
+
+       ChatRequest request = ChatRequest.builder()
+               .chatDetails(chatDetails)
+               .build();
+       ChatResponse response = generativeAiInferenceClientService.getClient().chat(request);
+       ChatResult chatResult = response.getChatResult();
+
+       // Cohere models return CohereChatResponse; other model families use the generic format
+       BaseChatResponse baseChatResponse = chatResult.getChatResponse();
+       if (baseChatResponse instanceof CohereChatResponse) {
+           return ((CohereChatResponse) baseChatResponse).getText();
+       } else if (baseChatResponse instanceof GenericChatResponse) {
+           List<ChatChoice> choices = ((GenericChatResponse) baseChatResponse).getChoices();
+           List<ChatContent> contents = choices.get(choices.size() - 1).getMessage().getContent();
+           ChatContent content = contents.get(contents.size() - 1);
+           if (content instanceof TextContent) {
+               return ((TextContent) content).getText();
+           }
+       }
+       throw new IllegalStateException("Unexpected chat response type: " + baseChatResponse.getClass().getName());
    }

-   public String summaryText(String input, String modelId) {
-       SummarizeTextDetails summarizeTextDetails = SummarizeTextDetails.builder()
-               .servingMode(OnDemandServingMode.builder().modelId(modelId).build())
-               .compartmentId(COMPARTMENT_ID)
-               .input(input)
-               .build();
-       SummarizeTextRequest request = SummarizeTextRequest.builder()
-               .summarizeTextDetails(summarizeTextDetails)
-               .build();
-       GenerativeAiInferenceClient client = generativeAiInferenceClientService.getClient();
-       SummarizeTextResponse summarizeTextResponse = client.summarizeText(request);
-       String summaryText = summarizeTextResponse.getSummarizeTextResult().getSummary();
-       return summaryText;
+   public String summaryText(String input, String modelId, boolean finetuned) {
+       // Summarization is now delegated to the chat endpoint instead of the SummarizeText API
+       String response = resolvePrompt("Summarize this:\n" + input, modelId, finetuned);
+       return response;
    }
}
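Note: the GenAiInferenceClientService injected above is defined elsewhere in the repository and is not part of this diff. As a rough sketch only (not the project's actual implementation), assuming OCI config-file authentication and a hypothetical genai.region property, such a bean could look like this:

import com.oracle.bmc.ConfigFileReader;
import com.oracle.bmc.Region;
import com.oracle.bmc.auth.ConfigFileAuthenticationDetailsProvider;
import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import java.io.IOException;

@Service
public class GenAiInferenceClientService {

    // Hypothetical property name; the real project may configure the region differently
    @Value("${genai.region:us-chicago-1}")
    private String region;

    private GenerativeAiInferenceClient client;

    // Lazily builds a client from the default ~/.oci/config profile (assumption)
    public synchronized GenerativeAiInferenceClient getClient() {
        if (client == null) {
            try {
                ConfigFileReader.ConfigFile configFile = ConfigFileReader.parseDefault();
                client = GenerativeAiInferenceClient.builder()
                        .region(Region.fromRegionCodeOrId(region))
                        .build(new ConfigFileAuthenticationDetailsProvider(configFile));
            } catch (IOException e) {
                throw new IllegalStateException("Unable to load OCI config file", e);
            }
        }
        return client;
    }
}

Whatever the actual implementation looks like, the only contract this diff relies on is a getClient() method returning a configured GenerativeAiInferenceClient.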