11package dev .victormartin .oci .genai .backend .backend .service ;
22
3- import com .oracle .bmc .generativeaiinference .GenerativeAiInferenceClient ;
43import com .oracle .bmc .generativeaiinference .model .*;
4+ import com .oracle .bmc .generativeaiinference .model .Message ;
55import com .oracle .bmc .generativeaiinference .requests .ChatRequest ;
6- import com .oracle .bmc .generativeaiinference .requests .GenerateTextRequest ;
7- import com .oracle .bmc .generativeaiinference .requests .SummarizeTextRequest ;
86import com .oracle .bmc .generativeaiinference .responses .ChatResponse ;
9- import com .oracle .bmc .generativeaiinference .responses .GenerateTextResponse ;
10- import com .oracle .bmc .generativeaiinference .responses .SummarizeTextResponse ;
11- import com .oracle .bmc .http .client .jersey .WrappedResponseInputStream ;
12- import org .hibernate .boot .archive .scan .internal .StandardScanner ;
7+ import dev .victormartin .oci .genai .backend .backend .dao .GenAiModel ;
8+ import org .slf4j .Logger ;
9+ import org .slf4j .LoggerFactory ;
1310import org .springframework .beans .factory .annotation .Autowired ;
1411import org .springframework .beans .factory .annotation .Value ;
1512import org .springframework .stereotype .Service ;
1613
17- import java .io .*;
18- import java .nio .charset .StandardCharsets ;
14+ import java .util .ArrayList ;
1915import java .util .List ;
20- import java .util .stream .Collectors ;
2116
2217@ Service
2318public class OCIGenAIService {
19+
20+ Logger log = LoggerFactory .getLogger (OCIGenAIService .class );
21+
2422 @ Value ("${genai.compartment_id}" )
2523 private String COMPARTMENT_ID ;
2624
2725 @ Autowired
2826 private GenAiInferenceClientService generativeAiInferenceClientService ;
2927
30- public String resolvePrompt (String input , String modelId , boolean finetune ) {
31- CohereChatRequest cohereChatRequest = CohereChatRequest .builder ()
32- .message (input )
33- .maxTokens (600 )
34- .temperature ((double ) 1 )
35- .frequencyPenalty ((double ) 0 )
36- .topP ((double ) 0.75 )
37- .topK (0 )
38- .isStream (false ) // TODO websockets and streams
39- .build ();
28+ @ Autowired
29+ private GenAIModelsService genAIModelsService ;
4030
41- ChatDetails chatDetails = ChatDetails .builder ()
42- .servingMode (OnDemandServingMode .builder ().modelId (modelId ).build ())
43- .compartmentId (COMPARTMENT_ID )
44- .chatRequest (cohereChatRequest )
45- .build ();
31+ public String resolvePrompt (String input , String modelId , boolean finetune , boolean summarization ) {
32+
33+ List <GenAiModel > models = genAIModelsService .getModels ();
34+ GenAiModel currentModel = models .stream ()
35+ .filter (m -> modelId .equals (m .id ()))
36+ .findFirst ()
37+ .orElseThrow ();
38+
39+ log .info ("Model {} with finetune {}" , currentModel .name (), finetune ? "yes" : "no" );
40+
41+ double temperature = summarization ?0.0 :0.5 ;
42+
43+ String inputText = summarization ?"Summarize this text:\n " + input : input ;
44+
45+ ChatDetails chatDetails ;
46+ switch (currentModel .vendor ()) {
47+ case "cohere" :
48+ CohereChatRequest cohereChatRequest = CohereChatRequest .builder ()
49+ .message (inputText )
50+ .maxTokens (600 )
51+ .temperature (temperature )
52+ .frequencyPenalty ((double ) 0 )
53+ .topP (0.75 )
54+ .topK (0 )
55+ .isStream (false ) // TODO websockets and streams
56+ .build ();
57+
58+ chatDetails = ChatDetails .builder ()
59+ .servingMode (OnDemandServingMode .builder ().modelId (currentModel .id ()).build ())
60+ .compartmentId (COMPARTMENT_ID )
61+ .chatRequest (cohereChatRequest )
62+ .build ();
63+ break ;
64+ case "meta" :
65+ ChatContent content = TextContent .builder ()
66+ .text (inputText )
67+ .build ();
68+ List <ChatContent > contents = new ArrayList <>();
69+ contents .add (content );
70+ List <Message > messages = new ArrayList <>();
71+ Message message = new UserMessage (contents , "user" );
72+ messages .add (message );
73+ GenericChatRequest genericChatRequest = GenericChatRequest .builder ()
74+ .messages (messages )
75+ .maxTokens (600 )
76+ .temperature ((double )1 )
77+ .frequencyPenalty ((double )0 )
78+ .presencePenalty ((double )0 )
79+ .topP (0.75 )
80+ .topK (-1 )
81+ .isStream (false )
82+ .build ();
83+ chatDetails = ChatDetails .builder ()
84+ .servingMode (OnDemandServingMode .builder ().modelId (currentModel .id ()).build ())
85+ .compartmentId (COMPARTMENT_ID )
86+ .chatRequest (genericChatRequest )
87+ .build ();
88+ break ;
89+ default :
90+ throw new IllegalStateException ("Unexpected value: " + currentModel .vendor ());
91+ }
4692
4793 ChatRequest request = ChatRequest .builder ()
4894 .chatDetails (chatDetails )
@@ -65,7 +111,7 @@ public String resolvePrompt(String input, String modelId, boolean finetune) {
65111 }
66112
67113 public String summaryText (String input , String modelId , boolean finetuned ) {
68- String response = resolvePrompt ("Summarize this: \n " + input , modelId , finetuned );
114+ String response = resolvePrompt (input , modelId , finetuned , true );
69115 return response ;
70116 }
71117}
0 commit comments