Skip to content

Commit a8d9ccd

Browse files
committed
add system instruction
update models introduce HarmCategoryType update HarmCategory
1 parent 55f41c9 commit a8d9ccd

File tree

5 files changed

+98
-25
lines changed

5 files changed

+98
-25
lines changed

gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ public Stream<GeneratedContent> generateContentStream(GenerativeModel model) {
215215
var gcr = jsonParser.fromJson(line.substring(STREAM_LINE_PREFIX_LENGTH), GenerateContentResponse.class);
216216
// each element can just replace the previous one
217217
this.responseById.put(uuid, gcr);
218+
// todo catch safety error
219+
//if ("SAFETY".equals(gcr.candidates().get(0).finishReason())) {
218220
return new GeneratedContent(uuid, gcr.candidates().get(0).content().parts().get(0).text());
219221
} catch (Exception e) {
220222
throw new RuntimeException("Unexpected line:\n" + line, e);
@@ -326,7 +328,18 @@ public CompletableFuture<List<ContentEmbedding>> embedContents(
326328

327329
private static GenerateContentRequest convert(GenerativeModel model) {
328330
List<GenerationContent> generationContents = convertGenerationContents(model);
329-
return new GenerateContentRequest(model.modelName(), generationContents, model.safetySettings(), model.generationConfig());
331+
return new GenerateContentRequest(
332+
model.modelName(),
333+
generationContents,
334+
model.safetySettings(),
335+
model.generationConfig(),
336+
model.systemInstruction().isEmpty() ? null :
337+
new SystemInstruction(
338+
model.systemInstruction().stream()
339+
.map(SystemInstructionPart::new)
340+
.toList()
341+
)
342+
);
330343
}
331344

332345
private static List<GenerationContent> convertGenerationContents(GenerativeModel model) {
@@ -578,7 +591,18 @@ private record GenerateContentRequest(
578591
String model,
579592
List<GenerationContent> contents,
580593
List<SafetySetting> safetySettings,
581-
GenerationConfig generationConfig
594+
GenerationConfig generationConfig,
595+
SystemInstruction systemInstruction
596+
) {
597+
}
598+
599+
private record SystemInstruction(
600+
List<SystemInstructionPart> parts
601+
) {
602+
}
603+
604+
private record SystemInstructionPart(
605+
String text
582606
) {
583607
}
584608

gemini-api/src/main/java/swiss/ameri/gemini/api/GenerativeModel.java

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66
/**
77
* Contains all the information needed for Gemini API to generate new content.
88
*
9-
* @param modelName to be used. see {@link ModelVariant}. Must start with "models/"
10-
* @param contents given as input to Gemini API
11-
* @param safetySettings optional, to adjust safety settings
12-
* @param generationConfig optional, to configure the prompt
9+
* @param modelName to be used. see {@link ModelVariant}. Must start with "models/"
10+
* @param contents given as input to Gemini API
11+
* @param safetySettings optional, to adjust safety settings
12+
* @param generationConfig optional, to configure the prompt
13+
* @param systemInstruction optional, system instruction
1314
*/
1415
public record GenerativeModel(
1516
String modelName,
1617
List<Content> contents,
1718
List<SafetySetting> safetySettings,
18-
GenerationConfig generationConfig
19+
GenerationConfig generationConfig,
20+
List<String> systemInstruction
1921
) {
2022

2123
/**
@@ -35,6 +37,7 @@ public static class GenerativeModelBuilder {
3537
private GenerationConfig generationConfig;
3638
private final List<Content> contents = new ArrayList<>();
3739
private final List<SafetySetting> safetySettings = new ArrayList<>();
40+
private final List<String> systemInstructions = new ArrayList<>();
3841

3942
private GenerativeModelBuilder() {
4043
}
@@ -71,6 +74,17 @@ public GenerativeModelBuilder addContent(Content content) {
7174
return this;
7275
}
7376

77+
/**
78+
* Add system instruction
79+
*
80+
* @param systemInstruction to be added
81+
* @return this
82+
*/
83+
public GenerativeModelBuilder addSystemInstruction(String systemInstruction) {
84+
this.systemInstructions.add(systemInstruction);
85+
return this;
86+
}
87+
7488
/**
7589
* Add safety setting
7690
*
@@ -99,7 +113,7 @@ public GenerativeModelBuilder generationConfig(GenerationConfig generationConfig
99113
* @return a completed (not necessarily validated) {@link GenerativeModel}
100114
*/
101115
public GenerativeModel build() {
102-
return new GenerativeModel(modelName, contents, safetySettings, generationConfig);
116+
return new GenerativeModel(modelName, contents, safetySettings, generationConfig, systemInstructions);
103117
}
104118
}
105119

gemini-api/src/main/java/swiss/ameri/gemini/api/ModelVariant.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,21 @@ public enum ModelVariant {
1515
*/
1616
GEMINI_1_5_FLASH("gemini-1.5-flash"),
1717
/**
18-
* Natural language tasks, multi-turn text and code chat, and code generation.
18+
* High volume and lower intelligence tasks.
1919
*/
20-
GEMINI_1_0_PRO("gemini-1.0-pro"),
20+
GEMINI_1_5_FLASH_8B("gemini-1.5-flash-8b"),
2121
/**
22-
* Visual-related tasks, like generating image descriptions or identifying objects in images.
22+
* Natural language tasks, multi-turn text and code chat, and code generation.
2323
*/
24-
GEMINI_1_0_PRO_VISION("gemini-pro-vision"),
24+
GEMINI_1_0_PRO("gemini-1.0-pro"),
2525
/**
2626
* Measuring the relatedness of text strings.
2727
*/
2828
TEXT_EMBEDDING_004("text-embedding-004"),
29+
/**
30+
* Providing source-grounded answers to questions.
31+
*/
32+
AQA("aqa"),
2933
;
3034

3135
private final String variant;

gemini-api/src/main/java/swiss/ameri/gemini/api/SafetySetting.java

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
package swiss.ameri.gemini.api;
22

3+
import java.util.Arrays;
4+
import java.util.List;
5+
6+
import static swiss.ameri.gemini.api.SafetySetting.HarmCategoryType.*;
7+
38
/**
49
* Safety settings according to <a href="https://ai.google.dev/api/rest/v1beta/SafetySetting#harmblockthreshold">SafetySetting</a>.
510
*
@@ -29,55 +34,81 @@ public static SafetySetting of(
2934
);
3035
}
3136

37+
public enum HarmCategoryType {
38+
GEMINI,
39+
PALM,
40+
UNKNOWN
41+
}
42+
3243
/**
3344
* According to <a href="https://ai.google.dev/api/rest/v1beta/HarmCategory">HarmCategory</a>.
34-
* Currently, only the first 4 seem to be recognized as input.
45+
* See {@link #harmCategoryType} for which can be used as input to a model.
3546
*/
3647
public enum HarmCategory {
48+
3749
/**
3850
* Harasment content.
3951
*/
40-
HARM_CATEGORY_HARASSMENT,
52+
HARM_CATEGORY_HARASSMENT(GEMINI),
4153
/**
4254
* Hate speech and content.
4355
*/
44-
HARM_CATEGORY_HATE_SPEECH,
56+
HARM_CATEGORY_HATE_SPEECH(GEMINI),
4557
/**
4658
* Sexually explicit content.
4759
*/
48-
HARM_CATEGORY_SEXUALLY_EXPLICIT,
60+
HARM_CATEGORY_SEXUALLY_EXPLICIT(GEMINI),
4961
/**
5062
* Dangerous content.
5163
*/
52-
HARM_CATEGORY_DANGEROUS_CONTENT,
64+
HARM_CATEGORY_DANGEROUS_CONTENT(GEMINI),
65+
/**
66+
* Content that may be used to harm civic integrity.
67+
*/
68+
HARM_CATEGORY_CIVIC_INTEGRITY(GEMINI),
5369
/**
5470
* Category is unspecified.
5571
*/
56-
HARM_CATEGORY_UNSPECIFIED,
72+
HARM_CATEGORY_UNSPECIFIED(UNKNOWN),
5773
/**
5874
* Negative or harmful comments targeting identity and/or protected attribute.
5975
*/
60-
HARM_CATEGORY_DEROGATORY,
76+
HARM_CATEGORY_DEROGATORY(PALM),
6177
/**
6278
* Content that is rude, disrespectful, or profane.
6379
*/
64-
HARM_CATEGORY_TOXICITY,
80+
HARM_CATEGORY_TOXICITY(PALM),
6581
/**
6682
* Describes scenarios depicting violence against an individual or group, or general descriptions of gore.
6783
*/
68-
HARM_CATEGORY_VIOLENCE,
84+
HARM_CATEGORY_VIOLENCE(PALM),
6985
/**
7086
* Contains references to sexual acts or other lewd content.
7187
*/
72-
HARM_CATEGORY_SEXUAL,
88+
HARM_CATEGORY_SEXUAL(PALM),
7389
/**
7490
* Promotes unchecked medical advice.
7591
*/
76-
HARM_CATEGORY_MEDICAL,
92+
HARM_CATEGORY_MEDICAL(PALM),
7793
/**
7894
* Dangerous content that promotes, facilitates, or encourages harmful acts.
7995
*/
80-
HARM_CATEGORY_DANGEROUS,
96+
HARM_CATEGORY_DANGEROUS(PALM);
97+
private final HarmCategoryType harmCategoryType;
98+
99+
HarmCategory(HarmCategoryType harmCategoryType) {
100+
this.harmCategoryType = harmCategoryType;
101+
}
102+
103+
public HarmCategoryType harmCategoryType() {
104+
return harmCategoryType;
105+
}
106+
107+
public static List<HarmCategory> harmCategoriesFor(HarmCategoryType type) {
108+
return Arrays.stream(values())
109+
.filter(category -> category.harmCategoryType == type)
110+
.toList();
111+
}
81112
}
82113

83114
/**

gemini-tester/src/main/java/swiss/ameri/gemini/tester/GeminiTester.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ private static void listModels(GenAi genAi) {
172172
private static void textAndImage(GenAi genAi) throws IOException {
173173
System.out.println("----- text and image");
174174
var model = GenerativeModel.builder()
175-
.modelName(ModelVariant.GEMINI_1_0_PRO_VISION)
175+
.modelName(ModelVariant.GEMINI_1_5_FLASH)
176176
.addContent(
177177
Content.textAndMediaContentBuilder()
178178
.role(Content.Role.USER)

0 commit comments

Comments
 (0)