diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt index 59188094b2c..5a86ee8de4a 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt @@ -21,10 +21,6 @@ import android.graphics.BitmapFactory import android.util.Base64 import com.google.firebase.vertexai.common.client.Schema import com.google.firebase.vertexai.common.shared.FileData -import com.google.firebase.vertexai.common.shared.FunctionCall -import com.google.firebase.vertexai.common.shared.FunctionCallPart -import com.google.firebase.vertexai.common.shared.FunctionResponse -import com.google.firebase.vertexai.common.shared.FunctionResponsePart import com.google.firebase.vertexai.common.shared.InlineData import com.google.firebase.vertexai.type.BlockReason import com.google.firebase.vertexai.type.Candidate @@ -34,8 +30,12 @@ import com.google.firebase.vertexai.type.Content import com.google.firebase.vertexai.type.CountTokensResponse import com.google.firebase.vertexai.type.FileDataPart import com.google.firebase.vertexai.type.FinishReason +import com.google.firebase.vertexai.type.FunctionCall +import com.google.firebase.vertexai.type.FunctionCallPart import com.google.firebase.vertexai.type.FunctionCallingConfig import com.google.firebase.vertexai.type.FunctionDeclaration +import com.google.firebase.vertexai.type.FunctionResponse +import com.google.firebase.vertexai.type.FunctionResponsePart import com.google.firebase.vertexai.type.GenerateContentResponse import com.google.firebase.vertexai.type.GenerationConfig import com.google.firebase.vertexai.type.HarmBlockMethod @@ -59,6 +59,7 @@ import java.io.ByteArrayOutputStream import java.util.Calendar import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive import org.json.JSONObject private const val BASE_64_FLAGS = Base64.NO_WRAP @@ -80,10 +81,10 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part com.google.firebase.vertexai.common.shared.InlineDataPart( InlineData(mimeType, Base64.encodeToString(inlineData, BASE_64_FLAGS)) ) - is com.google.firebase.vertexai.type.FunctionCallPart -> - FunctionCallPart(FunctionCall(name, args.orEmpty())) - is com.google.firebase.vertexai.type.FunctionResponsePart -> - FunctionResponsePart(FunctionResponse(name, response.toInternal())) + is FunctionCallPart -> + com.google.firebase.vertexai.common.shared.FunctionCallPart(functionCall.toInternal()) + is FunctionResponsePart -> + com.google.firebase.vertexai.common.shared.FunctionResponsePart(functionResponse.toInternal()) is FileDataPart -> com.google.firebase.vertexai.common.shared.FileDataPart( FileData(mimeType = mimeType, fileUri = uri) @@ -95,6 +96,15 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part } } +internal fun FunctionCall.toInternal() = + com.google.firebase.vertexai.common.shared.FunctionCall( + name, + args.orEmpty().mapValues { it.value.toString() } + ) + +internal fun FunctionResponse.toInternal() = + com.google.firebase.vertexai.common.shared.FunctionResponse(name, response) + internal fun SafetySetting.toInternal() = com.google.firebase.vertexai.common.shared.SafetySetting( harmCategory.toInternal(), @@ -213,16 +223,10 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part { InlineDataPart(inlineData.mimeType, data) } } - is FunctionCallPart -> - com.google.firebase.vertexai.type.FunctionCallPart( - functionCall.name, - functionCall.args.orEmpty(), - ) - is FunctionResponsePart -> - com.google.firebase.vertexai.type.FunctionResponsePart( - functionResponse.name, - functionResponse.response.toPublic(), - ) + is com.google.firebase.vertexai.common.shared.FunctionCallPart -> + FunctionCallPart(functionCall.toPublic()) + is com.google.firebase.vertexai.common.shared.FunctionResponsePart -> + FunctionResponsePart(functionResponse.toPublic()) is com.google.firebase.vertexai.common.shared.FileDataPart -> FileDataPart(fileData.mimeType, fileData.fileUri) else -> @@ -232,6 +236,21 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part { } } +internal fun com.google.firebase.vertexai.common.shared.FunctionCall.toPublic() = + FunctionCall( + name, + args.orEmpty().mapValues { + val argValue = it.value + if (argValue == null) JsonPrimitive(null) else Json.parseToJsonElement(argValue) + } + ) + +internal fun com.google.firebase.vertexai.common.shared.FunctionResponse.toPublic() = + FunctionResponse( + name, + response, + ) + internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic(): Citation { val publicationDateAsCalendar = publicationDate?.let { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt index fff7eae7959..ea9635a2ef4 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt @@ -17,7 +17,8 @@ package com.google.firebase.vertexai.type import android.graphics.Bitmap -import org.json.JSONObject +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject /** Interface representing data sent to and received from requests. */ interface Part @@ -44,20 +45,35 @@ class ImagePart(val image: Bitmap) : Part class InlineDataPart(val mimeType: String, val inlineData: ByteArray) : Part /** - * Represents function call name and params received from requests. + * Represents a function call request from the model + * + * @param functionCall The information provided by the model to call a function. + */ +class FunctionCallPart(val functionCall: FunctionCall) : Part + +/** + * The result of calling a function as requested by the model. + * + * @param functionResponse The information to send back to the model as the result of a functions + * call. + */ +class FunctionResponsePart(val functionResponse: FunctionResponse) : Part + +/** + * The data necessary to invoke function [name] using the arguments [args]. * * @param name the name of the function to call * @param args the function parameters and values as a [Map] */ -class FunctionCallPart(val name: String, val args: Map) : Part +class FunctionCall(val name: String, val args: Map) /** - * Represents function call output to be returned to the model when it requests a function call. + * The [response] generated after calling function [name]. * * @param name the name of the called function - * @param response the response produced by the function as a [JSONObject] + * @param response the response produced by the function as a [JsonObject] */ -class FunctionResponsePart(val name: String, val response: JSONObject) : Part +class FunctionResponse(val name: String, val response: JsonObject) /** * Represents file data stored in Cloud Storage for Firebase, referenced by URI. diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt index a2a12f632d6..50113660208 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt @@ -44,6 +44,7 @@ import io.kotest.matchers.types.shouldBeInstanceOf import io.ktor.http.HttpStatusCode import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.JsonPrimitive import org.json.JSONArray import org.junit.Test @@ -350,7 +351,7 @@ internal class UnarySnapshotTests { val response = model.generateContent("prompt") val callPart = (response.candidates.first().content.parts.first() as FunctionCallPart) - callPart.args["season"] shouldBe null + callPart.functionCall.args["season"] shouldBe JsonPrimitive(null) } } @@ -367,7 +368,7 @@ internal class UnarySnapshotTests { it.parts.first().shouldBeInstanceOf() } - callPart.args["current"] shouldBe "true" + callPart.functionCall.args["current"] shouldBe JsonPrimitive(true) } } @@ -378,8 +379,8 @@ internal class UnarySnapshotTests { val response = model.generateContent("prompt") val callPart = response.functionCalls.shouldNotBeEmpty().first() - callPart.name shouldBe "current_time" - callPart.args.isEmpty() shouldBe true + callPart.functionCall.name shouldBe "current_time" + callPart.functionCall.args.isEmpty() shouldBe true } } @@ -390,9 +391,9 @@ internal class UnarySnapshotTests { val response = model.generateContent("prompt") val callPart = response.functionCalls.shouldNotBeEmpty().first() - callPart.name shouldBe "sum" - callPart.args["x"] shouldBe "4" - callPart.args["y"] shouldBe "5" + callPart.functionCall.name shouldBe "sum" + callPart.functionCall.args["x"] shouldBe JsonPrimitive(4) + callPart.functionCall.args["y"] shouldBe JsonPrimitive(5) } } @@ -405,8 +406,8 @@ internal class UnarySnapshotTests { callList.size shouldBe 3 callList.forEach { - it.name shouldBe "sum" - it.args.size shouldBe 2 + it.functionCall.name shouldBe "sum" + it.functionCall.args.size shouldBe 2 } } } @@ -420,7 +421,7 @@ internal class UnarySnapshotTests { response.text shouldBe "The sum of [1, 2, 3] is" callList.size shouldBe 2 - callList.forEach { it.args.size shouldBe 2 } + callList.forEach { it.functionCall.args.size shouldBe 2 } } }