Skip to content

Commit 81f9096

Browse files
committed
prepare for 0.6.1 to return multi text output
prepare the RemoteLanguageModel to support multiple output responses.
1 parent 2bb94c8 commit 81f9096

File tree

3 files changed

+60
-10
lines changed

3 files changed

+60
-10
lines changed

core/com.intellijava.core/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
<groupId>io.github.barqawiz</groupId>
88
<artifactId>intellijava.core</artifactId>
9-
<version>0.6.0</version>
9+
<version>0.6.1</version>
1010

1111
<name>Intellijava</name>
1212
<description>IntelliJava allows java developers to easily integrate with the latest language models, image generation, and deep learning frameworks.</description>

core/com.intellijava.core/src/main/java/com/intellijava/core/controller/RemoteLanguageModel.java

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@
2121
import java.util.List;
2222
import java.util.Map;
2323
import com.intellijava.core.model.CohereLanguageResponse;
24+
import com.intellijava.core.model.CohereLanguageResponse.Generation;
2425
import com.intellijava.core.model.OpenaiLanguageResponse;
26+
import com.intellijava.core.model.OpenaiLanguageResponse.Choice;
2527
import com.intellijava.core.model.SupportedLangModels;
28+
import com.intellijava.core.model.OpenaiImageResponse.Data;
2629
import com.intellijava.core.model.input.LanguageModelInput;
2730
import com.intellijava.core.wrappers.CohereAIWrapper;
2831
import com.intellijava.core.wrappers.OpenAIWrapper;
@@ -143,11 +146,13 @@ private void initiate(String keyValue, SupportedLangModels keyType) {
143146
public String generateText(LanguageModelInput langInput) throws IOException {
144147

145148
if (this.keyType.equals(SupportedLangModels.openai)) {
146-
return this.generateOpenaiText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(),
147-
langInput.getMaxTokens());
149+
return this.generateOpenaiText(langInput.getModel(),
150+
langInput.getPrompt(), langInput.getTemperature(),
151+
langInput.getMaxTokens(), langInput.getNumberOfOutputs()).get(0);
148152
} else if (this.keyType.equals(SupportedLangModels.cohere)) {
149-
return this.generateCohereText(langInput.getModel(), langInput.getPrompt(), langInput.getTemperature(),
150-
langInput.getMaxTokens());
153+
return this.generateCohereText(langInput.getModel(),
154+
langInput.getPrompt(), langInput.getTemperature(),
155+
langInput.getMaxTokens(), langInput.getNumberOfOutputs()).get(0);
151156
} else {
152157
throw new IllegalArgumentException("This version support openai keyType only");
153158
}
@@ -163,11 +168,13 @@ public String generateText(LanguageModelInput langInput) throws IOException {
163168
* @param prompt text of the required action or the question.
164169
* @param temperature higher values means more risks and creativity.
165170
* @param maxTokens maximum size of the model input and output.
171+
* @param numberOfOutputs number of model outputs.
166172
* @return string model response.
167173
* @throws IOException if there is an error when connecting to the OpenAI API.
168174
*
169175
*/
170-
private String generateOpenaiText(String model, String prompt, float temperature, int maxTokens)
176+
private List<String> generateOpenaiText(String model, String prompt, float temperature,
177+
int maxTokens, int numberOfOutputs)
171178
throws IOException {
172179

173180
if (model.equals(""))
@@ -178,10 +185,16 @@ private String generateOpenaiText(String model, String prompt, float temperature
178185
params.put("prompt", prompt);
179186
params.put("temperature", temperature);
180187
params.put("max_tokens", maxTokens);
188+
params.put("n", numberOfOutputs);
181189

182190
OpenaiLanguageResponse resModel = (OpenaiLanguageResponse) openaiWrapper.generateText(params);
183191

184-
return resModel.getChoices().get(0).getText();
192+
List<String> outputs = new ArrayList<>();
193+
for (Choice item : resModel.getChoices()) {
194+
outputs.add(item.getText());
195+
}
196+
197+
return outputs;
185198

186199
}
187200

@@ -192,11 +205,13 @@ private String generateOpenaiText(String model, String prompt, float temperature
192205
* @param prompt text of the required action or the question.
193206
* @param temperature higher values means more risks and creativity.
194207
* @param maxTokens maximum size of the model input and output.
208+
* @param numberOfOutputs number of model outputs.
195209
* @return string model response.
196210
* @throws IOException if there is an error when connecting to the API.
197211
*
198212
*/
199-
private String generateCohereText(String model, String prompt, float temperature, int maxTokens)
213+
private List<String> generateCohereText(String model, String prompt, float temperature,
214+
int maxTokens, int numberOfOutputs)
200215
throws IOException {
201216

202217
if (model.equals(""))
@@ -207,10 +222,16 @@ private String generateCohereText(String model, String prompt, float temperature
207222
params.put("prompt", prompt);
208223
params.put("temperature", temperature);
209224
params.put("max_tokens", maxTokens);
225+
params.put("num_generations", numberOfOutputs);
210226

211227
CohereLanguageResponse resModel = (CohereLanguageResponse) cohereWrapper.generateText(params);
212-
213-
return resModel.getGenerations().get(0).getText();
228+
229+
List<String> outputs = new ArrayList<>();
230+
for (Generation item: resModel.getGenerations()) {
231+
outputs.add(item.getText());
232+
}
233+
234+
return outputs;
214235

215236
}
216237
}

core/com.intellijava.core/src/main/java/com/intellijava/core/model/input/LanguageModelInput.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ public class LanguageModelInput {
1717
private String prompt;
1818
private float temperature;
1919
private int maxTokens;
20+
private int numberOfOutputs = 1;
2021

2122
/**
2223
* Private Constructor for the Builder.
@@ -27,6 +28,7 @@ private LanguageModelInput(Builder builder) {
2728
this.prompt = builder.prompt;
2829
this.temperature = builder.temperature;
2930
this.maxTokens = builder.maxTokens;
31+
this.maxTokens = builder.numberOfOutputs;
3032
}
3133
/**
3234
*
@@ -38,6 +40,7 @@ public static class Builder {
3840
private String prompt;
3941
private float temperature;
4042
private int maxTokens;
43+
private int numberOfOutputs = 1;
4144

4245
/**
4346
* Language input Constructor.
@@ -90,6 +93,22 @@ public Builder setMaxTokens(int maxTokens) {
9093
this.maxTokens = maxTokens;
9194
return this;
9295
}
96+
97+
/**
98+
* Setter for numberOfOutputs
99+
* @param numberOfOutputs number of model outputs, default value is 1.
100+
*
101+
* Cohere maximum value is five.
102+
*
103+
* @return instance of Builder
104+
*/
105+
public Builder setNumberOfOutputs(int numberOfOutputs) {
106+
if (this.numberOfOutputs < 0)
107+
this.numberOfOutputs = 0;
108+
109+
this.numberOfOutputs = numberOfOutputs;
110+
return this;
111+
}
93112

94113
/**
95114
* Build the final LanguageModelInput object.
@@ -130,5 +149,15 @@ public float getTemperature() {
130149
public int getMaxTokens() {
131150
return maxTokens;
132151
}
152+
153+
/**
154+
* Getter for number of model outputs.
155+
* @return numberOfOutputs
156+
*/
157+
public int getNumberOfOutputs() {
158+
return numberOfOutputs;
159+
}
160+
161+
133162
}
134163

0 commit comments

Comments
 (0)