Skip to content

Commit e20ff59

Browse files
committed
add function calling
1 parent 3b89345 commit e20ff59

File tree

7 files changed

+297
-17
lines changed

7 files changed

+297
-17
lines changed

gemini-api/src/main/java/swiss/ameri/gemini/api/Content.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,41 @@
99
*/
1010
public sealed interface Content {
1111

12+
/**
13+
* Role belonging to this turn in the conversation.
14+
*
15+
* @return string value of a {@link Role}
16+
*/
17+
String role();
18+
19+
/**
20+
* Create a {@link FunctionCallContent}.
21+
*
22+
* @param role belonging to this turn in the conversation.
23+
* @param functionCall by the role
24+
* @return a {@link FunctionCallContent}
25+
*/
26+
static FunctionCallContent functionCallContent(
27+
Role role,
28+
FunctionCall functionCall
29+
) {
30+
return new FunctionCallContent(role == null ? null : role.roleName(), functionCall);
31+
}
32+
33+
/**
34+
* Create a {@link FunctionResponseContent}.
35+
*
36+
* @param role belonging to this turn in the conversation.
37+
* @param functionResponse by the role
38+
* @return a {@link FunctionResponseContent}
39+
*/
40+
static FunctionResponseContent functionResponseContent(
41+
Role role,
42+
FunctionResponse functionResponse
43+
) {
44+
return new FunctionResponseContent(role == null ? null : role.roleName(), functionResponse);
45+
}
46+
1247
/**
1348
* Create a {@link TextContent}.
1449
*
@@ -49,6 +84,32 @@ static TextAndMediaContent.TextAndMediaContentBuilder textAndMediaContentBuilder
4984
return TextAndMediaContent.builder();
5085
}
5186

87+
/**
88+
* A predicted FunctionCall returned from the model that contains a string representing the FunctionDeclaration.name
89+
* with the arguments and their values.
90+
*
91+
* @param role belonging to this turn in the conversation. see {@link Role}
92+
* @param functionCall returned from the model
93+
*/
94+
record FunctionCallContent(
95+
String role,
96+
FunctionCall functionCall
97+
) implements Content {
98+
}
99+
100+
/**
101+
* The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name and
102+
* a structured JSON object containing any output from the function is used as context to the model.
103+
* This should contain the result of a FunctionCall made based on model prediction.
104+
*
105+
* @param role belonging to this turn in the conversation. see {@link Role}
106+
* @param functionResponse response to a function call
107+
*/
108+
record FunctionResponseContent(
109+
String role,
110+
FunctionResponse functionResponse
111+
) implements Content {
112+
}
52113

53114
/**
54115
* A part of a conversation that contains text.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package swiss.ameri.gemini.api;
2+
3+
import java.util.Map;
4+
5+
/**
6+
* A predicted FunctionCall returned from the model that contains a string representing the FunctionDeclaration.name
7+
* with the arguments and their values.
8+
*
9+
* @param name Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
10+
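* @param format Optional. The function parameters and values in JSON object format. NOTE(review): the Gemini API names this field {@code args}; confirm the JSON mapping of this component.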
11+
*/
12+
public record FunctionCall(
13+
String name,
14+
Map<String, ?> format
15+
) {
16+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package swiss.ameri.gemini.api;
2+
3+
/**
4+
* Structured representation of a function declaration as defined by the OpenAPI 3.0.3 specification.
5+
* Included in this declaration are the function name and parameters.
6+
* This FunctionDeclaration is a representation of a block of code that can be used as a Tool by the model and executed by the client.
7+
*
8+
* @param name Required. The name of the function. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
9+
* @param description Required. A brief description of the function.
10+
* @param parameters Optional. Describes the parameters to this function.
11+
* Reflects the OpenAPI 3.0.3 Parameter Object. String key: the name of the parameter.
12+
* Parameter names are case-sensitive.
13+
* Schema Value: the Schema defining the type used for the parameter.
14+
*/
15+
public record FunctionDeclaration(
16+
String name,
17+
String description,
18+
Schema parameters
19+
) {
20+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package swiss.ameri.gemini.api;
2+
3+
import java.util.Map;
4+
5+
/**
6+
* The result output from a FunctionCall that contains a string representing the FunctionDeclaration.name and
7+
* a structured JSON object containing any output from the function is used as context to the model.
8+
* This should contain the result of a FunctionCall made based on model prediction.
9+
*
10+
* @param name Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 63.
11+
* @param response Required. The function response in JSON object format.
12+
*/
13+
public record FunctionResponse(
14+
String name,
15+
Map<String, ?> response
16+
) {
17+
}

gemini-api/src/main/java/swiss/ameri/gemini/api/GenAi.java

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
import java.net.http.HttpClient;
99
import java.net.http.HttpRequest;
1010
import java.net.http.HttpResponse;
11-
import java.util.*;
11+
import java.util.ArrayList;
12+
import java.util.Collection;
13+
import java.util.List;
14+
import java.util.Map;
15+
import java.util.Objects;
16+
import java.util.Optional;
17+
import java.util.UUID;
1218
import java.util.concurrent.CompletableFuture;
1319
import java.util.concurrent.ConcurrentHashMap;
1420
import java.util.stream.Collectors;
@@ -331,6 +337,10 @@ public CompletableFuture<List<ContentEmbedding>> embedContents(
331337

332338
private static GenerateContentRequest convert(GenerativeModel model) {
333339
List<GenerationContent> generationContents = convertGenerationContents(model);
340+
List<Tool> tools = new ArrayList<>();
341+
if (!model.functionDeclarations().isEmpty()) {
342+
tools.add(new Tool(model.functionDeclarations()));
343+
}
334344
return new GenerateContentRequest(
335345
model.modelName(),
336346
generationContents,
@@ -341,7 +351,9 @@ private static GenerateContentRequest convert(GenerativeModel model) {
341351
model.systemInstruction().stream()
342352
.map(SystemInstructionPart::new)
343353
.toList()
344-
)
354+
),
355+
tools.isEmpty() ? null :
356+
List.of(new Tool(model.functionDeclarations()))
345357
);
346358
}
347359

@@ -355,6 +367,8 @@ private static List<GenerationContent> convertGenerationContents(GenerativeModel
355367
List.of(
356368
new GenerationPart(
357369
textContent.text(),
370+
null,
371+
null,
358372
null
359373
)
360374
)
@@ -368,7 +382,9 @@ private static List<GenerationContent> convertGenerationContents(GenerativeModel
368382
new InlineData(
369383
imageContent.media().mimeType(),
370384
imageContent.media().mediaBase64()
371-
)
385+
),
386+
null,
387+
null
372388
)
373389
)
374390
);
@@ -379,6 +395,8 @@ private static List<GenerationContent> convertGenerationContents(GenerativeModel
379395
Stream.of(
380396
new GenerationPart(
381397
textAndImagesContent.text(),
398+
null,
399+
null,
382400
null
383401
)
384402
),
@@ -388,10 +406,36 @@ private static List<GenerationContent> convertGenerationContents(GenerativeModel
388406
new InlineData(
389407
imageData.mimeType(),
390408
imageData.mediaBase64()
391-
)
409+
),
410+
null,
411+
null
392412
))
393413
).toList()
394414
);
415+
} else if (content instanceof Content.FunctionCallContent functionCallContent) {
416+
return new GenerationContent(
417+
functionCallContent.role(),
418+
List.of(
419+
new GenerationPart(
420+
null,
421+
null,
422+
functionCallContent.functionCall(),
423+
null
424+
)
425+
)
426+
);
427+
} else if (content instanceof Content.FunctionResponseContent functionResponseContent) {
428+
return new GenerationContent(
429+
functionResponseContent.role(),
430+
List.of(
431+
new GenerationPart(
432+
null,
433+
null,
434+
null,
435+
functionResponseContent.functionResponse()
436+
)
437+
)
438+
);
395439
} else {
396440
throw new GeminiException("Unexpected content:\n" + content);
397441
}
@@ -423,11 +467,13 @@ public void close() {
423467
*
424468
* @param id the id of the request, for subsequent queries regarding metadata of the query
425469
* @param text of the generated content
470+
* @param functionCall optional, the function call requested by the model, if any
426471
* @param finishReason the reason generation was finished, according to <a href="https://ai.google.dev/api/generate-content#FinishReason">FinishReason</a>
427472
*/
428473
public record GeneratedContent(
429474
UUID id,
430475
String text,
476+
FunctionCall functionCall,
431477
String finishReason
432478
) {
433479
}
@@ -555,9 +601,10 @@ private GeneratedContent parse(String body, UUID uuid) {
555601
// we assume we always get a candidate. Otherwise, there is probably something wrong with the input
556602
var candidate = gcr.candidates().get(0);
557603
if (candidate.content() == null) {
558-
return new GeneratedContent(uuid, "", candidate.finishReason());
604+
return new GeneratedContent(uuid, "", null, candidate.finishReason());
559605
}
560-
return new GeneratedContent(uuid, candidate.content().parts().get(0).text(), candidate.finishReason());
606+
GenerationPart firstPart = candidate.content().parts().get(0);
607+
return new GeneratedContent(uuid, firstPart.text(), firstPart.functionCall(), candidate.finishReason());
561608
} catch (Exception e) {
562609
throw new GeminiException("Unexpected body:\n" + body, e);
563610
}
@@ -613,7 +660,17 @@ private record GenerateContentRequest(
613660
List<GenerationContent> contents,
614661
List<SafetySetting> safetySettings,
615662
GenerationConfig generationConfig,
616-
SystemInstruction systemInstruction
663+
SystemInstruction systemInstruction,
664+
List<Tool> tools
665+
) {
666+
}
667+
668+
/**
669+
* See <a href="https://ai.google.dev/api/caching#Tool">Tool</a>
670+
*/
671+
private record Tool(
672+
List<FunctionDeclaration> functionDeclarations
673+
// still missing CodeExecution and GoogleSearchRetrieval
617674
) {
618675
}
619676

@@ -633,10 +690,15 @@ private record GenerationContent(
633690
) {
634691
}
635692

693+
/**
694+
* See <a href="https://ai.google.dev/api/caching#Part">Part</a>
695+
*/
636696
private record GenerationPart(
637-
// contains one or the other
697+
// contains one of these
638698
String text,
639-
InlineData inline_data
699+
InlineData inline_data,
700+
FunctionCall functionCall,
701+
FunctionResponse functionResponse
640702
) {
641703
}
642704

gemini-api/src/main/java/swiss/ameri/gemini/api/GenerativeModel.java

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,20 @@
66
/**
77
* Contains all the information needed for Gemini API to generate new content.
88
*
9-
* @param modelName to be used. see {@link ModelVariant}. Must start with "models/"
10-
* @param contents given as input to Gemini API
11-
* @param safetySettings optional, to adjust safety settings
12-
* @param generationConfig optional, to configure the prompt
13-
* @param systemInstruction optional, system instruction
9+
* @param modelName to be used. see {@link ModelVariant}. Must start with "models/"
10+
* @param contents given as input to Gemini API
11+
* @param safetySettings optional, to adjust safety settings
12+
* @param generationConfig optional, to configure the prompt
13+
* @param systemInstruction optional, system instruction
14+
* @param functionDeclarations optional, functions the model may call
1415
*/
1516
public record GenerativeModel(
1617
String modelName,
1718
List<Content> contents,
1819
List<SafetySetting> safetySettings,
1920
GenerationConfig generationConfig,
20-
List<String> systemInstruction
21+
List<String> systemInstruction,
22+
List<FunctionDeclaration> functionDeclarations
2123
) {
2224

2325
/**
@@ -38,6 +40,7 @@ public static class GenerativeModelBuilder {
3840
private final List<Content> contents = new ArrayList<>();
3941
private final List<SafetySetting> safetySettings = new ArrayList<>();
4042
private final List<String> systemInstructions = new ArrayList<>();
43+
private final List<FunctionDeclaration> functionDeclarations = new ArrayList<>();
4144

4245
private GenerativeModelBuilder() {
4346
}
@@ -96,6 +99,17 @@ public GenerativeModelBuilder addSafetySetting(SafetySetting safetySetting) {
9699
return this;
97100
}
98101

102+
/**
103+
* Add a function declaration
104+
*
105+
* @param functionDeclaration to be added
106+
* @return this
107+
*/
108+
public GenerativeModelBuilder addFunctionDeclaration(FunctionDeclaration functionDeclaration) {
109+
this.functionDeclarations.add(functionDeclaration);
110+
return this;
111+
}
112+
99113
/**
100114
* Set the generation config
101115
*
@@ -113,7 +127,14 @@ public GenerativeModelBuilder generationConfig(GenerationConfig generationConfig
113127
* @return a completed (not necessarily validated) {@link GenerativeModel}
114128
*/
115129
public GenerativeModel build() {
116-
return new GenerativeModel(modelName, contents, safetySettings, generationConfig, systemInstructions);
130+
return new GenerativeModel(
131+
modelName,
132+
contents,
133+
safetySettings,
134+
generationConfig,
135+
systemInstructions,
136+
functionDeclarations
137+
);
117138
}
118139
}
119140

0 commit comments

Comments
 (0)