Skip to content

Commit efb90ba

Browse files
committed
Bugfix in tool calling with role tool
1 parent c371073 commit efb90ba

File tree

2 files changed

+47
-35
lines changed

2 files changed

+47
-35
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntity.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.common.Strings;
1212
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
13+
import org.elasticsearch.core.Nullable;
1314
import org.elasticsearch.inference.UnifiedCompletionRequest;
1415
import org.elasticsearch.rest.RestStatus;
1516
import org.elasticsearch.xcontent.ToXContentObject;
@@ -60,6 +61,7 @@ public class GoogleVertexAiUnifiedChatCompletionRequestEntity implements ToXCont
6061
private static final String MODEL_ROLE = "model";
6162
private static final String ASSISTANT_ROLE = "assistant";
6263
private static final String SYSTEM_ROLE = "system";
64+
private static final String TOOL_ROLE = "tool";
6365
private static final String STOP_SEQUENCES = "stopSequences";
6466

6567
private static final String SYSTEM_INSTRUCTION = "systemInstruction";
@@ -76,6 +78,9 @@ private String messageRoleToGoogleVertexAiSupportedRole(String messageRole) {
7678
} else if (messageRole.equals(ASSISTANT_ROLE)) {
7779
// Gemini VertexAI API does not use "assistant". Instead, it uses "model"
7880
return MODEL_ROLE;
81+
} else if (messageRole.equals(TOOL_ROLE)) {
82+
// Gemini VertexAI does not have the tool role, so we map it to "model"
83+
return MODEL_ROLE;
7984
}
8085

8186
var errorMessage = format(

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/request/GoogleVertexAiUnifiedChatCompletionRequestEntityTests.java

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -447,46 +447,53 @@ public void testError_UnsupportedContentObjectType() throws IOException {
447447
public void testParseAllFields() throws IOException {
448448
String requestJson = """
449449
{
450-
"contents": [
450+
"contents": [
451+
{
452+
"role": "user",
453+
"parts": [
451454
{
452-
"role": "user",
453-
"parts": [
454-
{ "text": "some text" },
455-
{ "functionCall" : {
456-
"name": "get_delivery_date",
457-
"args": {
458-
"order_id" : "order_12345"
459-
}
460-
}
461-
}
462-
]
463-
}
464-
],
465-
"generationConfig": {
466-
"stopSequences": ["stop"],
467-
"temperature": 0.1,
468-
"maxOutputTokens": 100,
469-
"topP": 0.2
470-
},
471-
"tools": [
455+
"text": "some text"
456+
},
472457
{
473-
"functionDeclarations": [
474-
{
475-
"name": "get_current_weather",
476-
"description": "Get the current weather in a given location",
477-
"parameters": {
478-
"type": "object"
479-
}
480-
}
481-
]
458+
"functionCall": {
459+
"name": "get_delivery_date",
460+
"args": {
461+
"order_id": "order_12345"
462+
}
463+
}
482464
}
465+
]
466+
}
467+
],
468+
"generationConfig": {
469+
"stopSequences": [
470+
"stop"
483471
],
484-
"toolConfig": {
485-
"functionCallingConfig" : {
486-
"mode": "ANY",
487-
"allowedFunctionNames": [ "some function" ]
472+
"temperature": 0.1,
473+
"maxOutputTokens": 100,
474+
"topP": 0.2
475+
},
476+
"tools": [
477+
{
478+
"functionDeclarations": [
479+
{
480+
"name": "get_current_weather",
481+
"description": "Get the current weather in a given location",
482+
"parameters": {
483+
"type": "object"
484+
}
488485
}
486+
]
489487
}
488+
],
489+
"toolConfig": {
490+
"functionCallingConfig": {
491+
"mode": "ANY",
492+
"allowedFunctionNames": [
493+
"some function"
494+
]
495+
}
496+
}
490497
}
491498
""";
492499

@@ -561,7 +568,7 @@ public void testParseFunctionCallNoContent() throws IOException {
561568
List.of(
562569
new UnifiedCompletionRequest.Message(
563570
null,
564-
"assistant",
571+
"tool",
565572
"100",
566573
List.of(
567574
new UnifiedCompletionRequest.ToolCall(

0 commit comments

Comments
 (0)