2121import java .util .List ;
2222import java .util .Map ;
2323import com .intellijava .core .model .CohereLanguageResponse ;
24+ import com .intellijava .core .model .CohereLanguageResponse .Generation ;
2425import com .intellijava .core .model .OpenaiLanguageResponse ;
26+ import com .intellijava .core .model .OpenaiLanguageResponse .Choice ;
2527import com .intellijava .core .model .SupportedLangModels ;
28+ import com .intellijava .core .model .OpenaiImageResponse .Data ;
2629import com .intellijava .core .model .input .LanguageModelInput ;
2730import com .intellijava .core .wrappers .CohereAIWrapper ;
2831import com .intellijava .core .wrappers .OpenAIWrapper ;
@@ -143,11 +146,13 @@ private void initiate(String keyValue, SupportedLangModels keyType) {
143146 public String generateText (LanguageModelInput langInput ) throws IOException {
144147
145148 if (this .keyType .equals (SupportedLangModels .openai )) {
146- return this .generateOpenaiText (langInput .getModel (), langInput .getPrompt (), langInput .getTemperature (),
147- langInput .getMaxTokens ());
149+ return this .generateOpenaiText (langInput .getModel (),
150+ langInput .getPrompt (), langInput .getTemperature (),
151+ langInput .getMaxTokens (), langInput .getNumberOfOutputs ()).get (0 );
148152 } else if (this .keyType .equals (SupportedLangModels .cohere )) {
149- return this .generateCohereText (langInput .getModel (), langInput .getPrompt (), langInput .getTemperature (),
150- langInput .getMaxTokens ());
153+ return this .generateCohereText (langInput .getModel (),
154+ langInput .getPrompt (), langInput .getTemperature (),
155+ langInput .getMaxTokens (), langInput .getNumberOfOutputs ()).get (0 );
151156 } else {
152157 throw new IllegalArgumentException ("This version support openai keyType only" );
153158 }
@@ -163,11 +168,13 @@ public String generateText(LanguageModelInput langInput) throws IOException {
163168 * @param prompt text of the required action or the question.
164169 * @param temperature higher values means more risks and creativity.
165170 * @param maxTokens maximum size of the model input and output.
171+ * @param numberOfOutputs number of model outputs.
166172 * @return string model response.
167173 * @throws IOException if there is an error when connecting to the OpenAI API.
168174 *
169175 */
170- private String generateOpenaiText (String model , String prompt , float temperature , int maxTokens )
176+ private List <String > generateOpenaiText (String model , String prompt , float temperature ,
177+ int maxTokens , int numberOfOutputs )
171178 throws IOException {
172179
173180 if (model .equals ("" ))
@@ -178,10 +185,16 @@ private String generateOpenaiText(String model, String prompt, float temperature
178185 params .put ("prompt" , prompt );
179186 params .put ("temperature" , temperature );
180187 params .put ("max_tokens" , maxTokens );
188+ params .put ("n" , numberOfOutputs );
181189
182190 OpenaiLanguageResponse resModel = (OpenaiLanguageResponse ) openaiWrapper .generateText (params );
183191
184- return resModel .getChoices ().get (0 ).getText ();
192+ List <String > outputs = new ArrayList <>();
193+ for (Choice item : resModel .getChoices ()) {
194+ outputs .add (item .getText ());
195+ }
196+
197+ return outputs ;
185198
186199 }
187200
@@ -192,11 +205,13 @@ private String generateOpenaiText(String model, String prompt, float temperature
192205 * @param prompt text of the required action or the question.
193206 * @param temperature higher values means more risks and creativity.
194207 * @param maxTokens maximum size of the model input and output.
208+ * @param numberOfOutputs number of model outputs.
195209 * @return string model response.
196210 * @throws IOException if there is an error when connecting to the API.
197211 *
198212 */
199- private String generateCohereText (String model , String prompt , float temperature , int maxTokens )
213+ private List <String > generateCohereText (String model , String prompt , float temperature ,
214+ int maxTokens , int numberOfOutputs )
200215 throws IOException {
201216
202217 if (model .equals ("" ))
@@ -207,10 +222,16 @@ private String generateCohereText(String model, String prompt, float temperature
207222 params .put ("prompt" , prompt );
208223 params .put ("temperature" , temperature );
209224 params .put ("max_tokens" , maxTokens );
225+ params .put ("num_generations" , numberOfOutputs );
210226
211227 CohereLanguageResponse resModel = (CohereLanguageResponse ) cohereWrapper .generateText (params );
212-
213- return resModel .getGenerations ().get (0 ).getText ();
228+
229+ List <String > outputs = new ArrayList <>();
230+ for (Generation item : resModel .getGenerations ()) {
231+ outputs .add (item .getText ());
232+ }
233+
234+ return outputs ;
214235
215236 }
216237}
0 commit comments