From c876aad0545b3d9b53437a22c70e769a4025db57 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Fri, 31 May 2024 15:52:23 -0700 Subject: [PATCH 1/6] make args nullable in parsing and add test, also specify that we expect it to not be null --- .../generativeai/common/client/Types.kt | 1 + .../generativeai/common/shared/Types.kt | 2 +- .../generativeai/common/UnarySnapshotTests.kt | 12 ++++ .../unary/success-function-call-null.json | 57 +++++++++++++++++++ .../generativeai/internal/util/conversions.kt | 2 + .../generativeai/type/FunctionDeclarations.kt | 13 +++-- .../ai/client/generativeai/type/Part.kt | 2 +- 7 files changed, 82 insertions(+), 7 deletions(-) create mode 100644 common/src/test/resources/golden-files/unary/success-function-call-null.json diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt index cb2fb478..c8fc3fb6 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/client/Types.kt @@ -59,6 +59,7 @@ data class Schema( val type: String, val description: String? = null, val format: String? = null, + val nullable: Boolean? = false, val enum: List? = null, val properties: Map? = null, val required: List? = null, diff --git a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt index 29882560..64e816e2 100644 --- a/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt +++ b/common/src/main/kotlin/com/google/ai/client/generativeai/common/shared/Types.kt @@ -59,7 +59,7 @@ data class Content(@EncodeDefault val role: String? = "user", val parts: List) +@Serializable data class FunctionCall(val name: String, val args: Map) @Serializable data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part diff --git a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt index a3f8d77f..13ec5428 100644 --- a/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt +++ b/common/src/test/java/com/google/ai/client/generativeai/common/UnarySnapshotTests.kt @@ -20,6 +20,7 @@ import com.google.ai.client.generativeai.common.server.BlockReason import com.google.ai.client.generativeai.common.server.FinishReason import com.google.ai.client.generativeai.common.server.HarmProbability import com.google.ai.client.generativeai.common.server.HarmSeverity +import com.google.ai.client.generativeai.common.shared.FunctionCallPart import com.google.ai.client.generativeai.common.shared.HarmCategory import com.google.ai.client.generativeai.common.shared.TextPart import com.google.ai.client.generativeai.common.util.goldenUnaryFile @@ -301,4 +302,15 @@ internal class UnarySnapshotTests { } } } + + @Test + fun `function call contains null param`() = + goldenUnaryFile("success-function-call-null.json") { + withTimeout(testTimeout) { + val response = apiController.generateContent(textGenerateContentRequest("prompt")) + val callPart = (response.candidates!!.first().content!!.parts.first() as FunctionCallPart) + + callPart.functionCall.args["season"] shouldBe null + } + } } diff --git a/common/src/test/resources/golden-files/unary/success-function-call-null.json b/common/src/test/resources/golden-files/unary/success-function-call-null.json new file mode 100644 index 00000000..25a696e6 --- /dev/null +++ b/common/src/test/resources/golden-files/unary/success-function-call-null.json @@ -0,0 +1,57 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "functionName", + "args": { + "cast": "String", + "brief_description": "String", + "original_title": "String", + "season": null, + "year": "1999", + "rating": "8.7", + "genres": "String", + "country": "САЩ", + "country_origin_english": "US", + "more_info": "String", + "title": "String", + "duration": "136", + "types": "String", + "complete_description": "String" + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 774, + "candidatesTokenCount": 4176, + "totalTokenCount": 4950 + } +} diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt index dadc2611..dfa8c3c0 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/internal/util/conversions.kt @@ -150,6 +150,7 @@ internal fun FunctionDeclaration.toInternal() = properties = getParameters().associate { it.name to it.toInternal() }, required = getParameters().map { it.name }, type = "OBJECT", + nullable = false, ), ) @@ -158,6 +159,7 @@ internal fun com.google.ai.client.generativeai.type.Schema.toInternal(): type.name, description, format, + nullable, enum, properties?.mapValues { it.value.toInternal() }, required, diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 80569ad7..37051ce5 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -173,6 +173,7 @@ class Schema( val name: String, val description: String, val format: String? = null, + val nullable: Boolean? = null, val enum: List? = null, val properties: Map>? = null, val required: List? = null, @@ -184,19 +185,19 @@ class Schema( companion object { /** Registers a schema for an integer number */ fun int(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.INTEGER) + Schema(name = name, description = description, type = FunctionType.INTEGER, nullable = false) /** Registers a schema for a string */ fun str(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.STRING) + Schema(name = name, description = description, type = FunctionType.STRING, nullable = false) /** Registers a schema for a boolean */ fun bool(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.BOOLEAN) + Schema(name = name, description = description, type = FunctionType.BOOLEAN, nullable = false) /** Registers a schema for a floating point number */ fun num(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.NUMBER) + Schema(name = name, description = description, type = FunctionType.NUMBER, nullable = false) /** * Registers a schema for a complex object. In a function it will be returned as a [JSONObject] @@ -208,11 +209,12 @@ class Schema( type = FunctionType.OBJECT, required = contents.map { it.name }, properties = contents.associateBy { it.name }.toMap(), + nullable = false, ) /** Registers a schema for an array */ fun arr(name: String, description: String) = - Schema>(name = name, description = description, type = FunctionType.ARRAY) + Schema>(name = name, description = description, type = FunctionType.ARRAY, nullable = false) /** Registers a schema for an enum */ fun enum(name: String, description: String, values: List) = @@ -222,6 +224,7 @@ class Schema( format = "enum", enum = values, type = FunctionType.STRING, + nullable = false, ) } } diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt index 4a65da12..b72d99a0 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/Part.kt @@ -49,7 +49,7 @@ class FileDataPart(val uri: String, val mimeType: String) : Part fun Part.asFileDataPartOrNull(): FileDataPart? = this as? FileDataPart /** Represents function call name and params received from requests. */ -class FunctionCallPart(val name: String, val args: Map) : Part +class FunctionCallPart(val name: String, val args: Map) : Part /** Represents function call output to be returned to the model when it requests a function call */ class FunctionResponsePart(val name: String, val response: JSONObject) : Part From 039b66c502dec06c043ed1dc1ce4cf258b53e0dc Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Fri, 31 May 2024 15:53:01 -0700 Subject: [PATCH 2/6] ktfmt --- .../generativeai/type/FunctionDeclarations.kt | 35 ++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt index 37051ce5..21d02b08 100644 --- a/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt +++ b/generativeai/src/main/java/com/google/ai/client/generativeai/type/FunctionDeclarations.kt @@ -185,19 +185,39 @@ class Schema( companion object { /** Registers a schema for an integer number */ fun int(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.INTEGER, nullable = false) + Schema( + name = name, + description = description, + type = FunctionType.INTEGER, + nullable = false, + ) /** Registers a schema for a string */ fun str(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.STRING, nullable = false) + Schema( + name = name, + description = description, + type = FunctionType.STRING, + nullable = false, + ) /** Registers a schema for a boolean */ fun bool(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.BOOLEAN, nullable = false) + Schema( + name = name, + description = description, + type = FunctionType.BOOLEAN, + nullable = false, + ) /** Registers a schema for a floating point number */ fun num(name: String, description: String) = - Schema(name = name, description = description, type = FunctionType.NUMBER, nullable = false) + Schema( + name = name, + description = description, + type = FunctionType.NUMBER, + nullable = false, + ) /** * Registers a schema for a complex object. In a function it will be returned as a [JSONObject] @@ -214,7 +234,12 @@ class Schema( /** Registers a schema for an array */ fun arr(name: String, description: String) = - Schema>(name = name, description = description, type = FunctionType.ARRAY, nullable = false) + Schema>( + name = name, + description = description, + type = FunctionType.ARRAY, + nullable = false, + ) /** Registers a schema for an enum */ fun enum(name: String, description: String, values: List) = From 87e6270b258ba2e88b969c82903d8c59af46d913 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Fri, 31 May 2024 15:55:53 -0700 Subject: [PATCH 3/6] update test file --- .../unary/success-function-call-null.json | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/common/src/test/resources/golden-files/unary/success-function-call-null.json b/common/src/test/resources/golden-files/unary/success-function-call-null.json index 25a696e6..14801eef 100644 --- a/common/src/test/resources/golden-files/unary/success-function-call-null.json +++ b/common/src/test/resources/golden-files/unary/success-function-call-null.json @@ -7,20 +7,8 @@ "functionCall": { "name": "functionName", "args": { - "cast": "String", - "brief_description": "String", "original_title": "String", - "season": null, - "year": "1999", - "rating": "8.7", - "genres": "String", - "country": "САЩ", - "country_origin_english": "US", - "more_info": "String", - "title": "String", - "duration": "136", - "types": "String", - "complete_description": "String" + "season": null } } } From f9b9c7c1994dd57652e800bbc93df17ce13e2b50 Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Fri, 31 May 2024 16:35:16 -0700 Subject: [PATCH 4/6] new test --- generativeai/build.gradle.kts | 1 + .../generativeai/FunctionCallingTests.kt | 61 +++++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt diff --git a/generativeai/build.gradle.kts b/generativeai/build.gradle.kts index ba88cefb..7a8dfebd 100644 --- a/generativeai/build.gradle.kts +++ b/generativeai/build.gradle.kts @@ -89,6 +89,7 @@ dependencies { testImplementation("io.kotest:kotest-assertions-core:5.5.5") testImplementation("io.kotest:kotest-assertions-core-jvm:5.5.5") testImplementation("io.mockk:mockk:1.12.8") + testImplementation("org.json:json:20240303") androidTestImplementation("androidx.test.ext:junit:1.1.5") androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt new file mode 100644 index 00000000..78109bc7 --- /dev/null +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt @@ -0,0 +1,61 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.ai.client.generativeai + +import com.google.ai.client.generativeai.type.FunctionCallPart +import com.google.ai.client.generativeai.type.Schema +import com.google.ai.client.generativeai.type.Tool +import com.google.ai.client.generativeai.type.defineFunction +import io.kotest.assertions.throwables.shouldThrowWithMessage +import io.kotest.matchers.shouldBe +import org.json.JSONObject +import org.junit.Test + +internal class FunctionCallingTests { + + @Test + fun `function calls with valid args should succeed`() = doBlocking { + val functionDeclaration = + defineFunction("name", "description", Schema.str("param1", "description")) { param1 -> + JSONObject(mapOf("result" to "success")) + } + val model = GenerativeModel("model", "key", tools = listOf(Tool(listOf(functionDeclaration)))) + + val functionCall = FunctionCallPart("name", mapOf("param1" to "valid parameter")) + + val result = model.executeFunction(functionCall) + + result["result"] shouldBe "success" + } + + @Test + fun `function calls with invalid args should fail`() = doBlocking { + val functionDeclaration = + defineFunction("name", "description", Schema.str("param1", "description")) { param1 -> + JSONObject(mapOf("result" to "success")) + } + val model = GenerativeModel("model", "key", tools = listOf(Tool(listOf(functionDeclaration)))) + + val functionCall = FunctionCallPart("name", mapOf("param1" to null)) + + shouldThrowWithMessage( + "Missing argument for parameter \"param1\" for function \"name\"" + ) { + model.executeFunction(functionCall) + } + } +} From 309e070d23674174320c925e37a93a1339a048ee Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Mon, 3 Jun 2024 14:45:04 -0700 Subject: [PATCH 5/6] add extra schema to the test --- .../generativeai/FunctionCallingTests.kt | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt index 78109bc7..5f38fc91 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt @@ -30,12 +30,28 @@ internal class FunctionCallingTests { @Test fun `function calls with valid args should succeed`() = doBlocking { val functionDeclaration = - defineFunction("name", "description", Schema.str("param1", "description")) { param1 -> + defineFunction( + "name", + "description", + Schema.str("param1", "description"), + Schema.num("param2", "description"), + Schema.bool("param3", "description"), + Schema.int("param4", "description"), + ) { param1, param2, param3, param4 -> JSONObject(mapOf("result" to "success")) } val model = GenerativeModel("model", "key", tools = listOf(Tool(listOf(functionDeclaration)))) - val functionCall = FunctionCallPart("name", mapOf("param1" to "valid parameter")) + val functionCall = + FunctionCallPart( + "name", + mapOf( + ("param1" to "valid parameter"), + ("param2" to "2.2"), + ("param3" to "false"), + ("param4" to "2"), + ) + ) val result = model.executeFunction(functionCall) From 9a726c4d865f083663ce13d851fdc1fa6d75af3c Mon Sep 17 00:00:00 2001 From: David Motsonashvili Date: Mon, 3 Jun 2024 14:52:23 -0700 Subject: [PATCH 6/6] ktfmt --- .../com/google/ai/client/generativeai/FunctionCallingTests.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt b/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt index 5f38fc91..0f862fe5 100644 --- a/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt +++ b/generativeai/src/test/java/com/google/ai/client/generativeai/FunctionCallingTests.kt @@ -50,7 +50,7 @@ internal class FunctionCallingTests { ("param2" to "2.2"), ("param3" to "false"), ("param4" to "2"), - ) + ), ) val result = model.executeFunction(functionCall)