diff --git a/firebase-vertexai/CHANGELOG.md b/firebase-vertexai/CHANGELOG.md index 2a425be615a..32228b2e628 100644 --- a/firebase-vertexai/CHANGELOG.md +++ b/firebase-vertexai/CHANGELOG.md @@ -2,7 +2,6 @@ * [feature] Added support for `title` and `publicationDate` in citations. (#6309) * [feature] Added support for `frequencyPenalty`, `presencePenalty`, and `HarmBlockMethod`. (#6309) * [changed] **Breaking Change**: Introduced `Citations` class. Now `CitationMetadata` wraps that type. (#6276) -* [changed] **Breaking Change**: Introduced `FunctionCall` and `FunctionResponse` types. Now `FunctionCallPart` and `FunctionResponsePart` wrap those types, respectively. (#6311) * [changed] **Breaking Change**: Reworked `Schema` declaration mechanism. (#6258) * [changed] **Breaking Change**: Reworked function calling mechanism to use the new `Schema` format. Function calls no longer use native types, nor include references to the actual executable code. (#6258) * [changed] **Breaking Change**: Made `totalBillableCharacters` field in `CountTokens` nullable and optional. (#6294) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt index 880615aa4f7..44ee9797202 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt @@ -21,6 +21,10 @@ import android.graphics.BitmapFactory import android.util.Base64 import com.google.firebase.vertexai.common.client.Schema import com.google.firebase.vertexai.common.shared.FileData +import com.google.firebase.vertexai.common.shared.FunctionCall +import com.google.firebase.vertexai.common.shared.FunctionCallPart +import com.google.firebase.vertexai.common.shared.FunctionResponse +import com.google.firebase.vertexai.common.shared.FunctionResponsePart import com.google.firebase.vertexai.common.shared.InlineData import com.google.firebase.vertexai.type.BlockReason import com.google.firebase.vertexai.type.Candidate @@ -30,12 +34,8 @@ import com.google.firebase.vertexai.type.Content import com.google.firebase.vertexai.type.CountTokensResponse import com.google.firebase.vertexai.type.FileDataPart import com.google.firebase.vertexai.type.FinishReason -import com.google.firebase.vertexai.type.FunctionCall -import com.google.firebase.vertexai.type.FunctionCallPart import com.google.firebase.vertexai.type.FunctionCallingConfig import com.google.firebase.vertexai.type.FunctionDeclaration -import com.google.firebase.vertexai.type.FunctionResponse -import com.google.firebase.vertexai.type.FunctionResponsePart import com.google.firebase.vertexai.type.GenerateContentResponse import com.google.firebase.vertexai.type.GenerationConfig import com.google.firebase.vertexai.type.HarmBlockMethod @@ -81,10 +81,10 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part com.google.firebase.vertexai.common.shared.InlineDataPart( InlineData(mimeType, Base64.encodeToString(inlineData, BASE_64_FLAGS)) ) - is FunctionCallPart -> - com.google.firebase.vertexai.common.shared.FunctionCallPart(functionCall.toInternal()) - is FunctionResponsePart -> - com.google.firebase.vertexai.common.shared.FunctionResponsePart(functionResponse.toInternal()) + is com.google.firebase.vertexai.type.FunctionCallPart -> + FunctionCallPart(FunctionCall(name, args)) + is com.google.firebase.vertexai.type.FunctionResponsePart -> + FunctionResponsePart(FunctionResponse(name, response)) is FileDataPart -> com.google.firebase.vertexai.common.shared.FileDataPart( FileData(mimeType = mimeType, fileUri = uri) @@ -96,12 +96,6 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part } } -internal fun FunctionCall.toInternal() = - com.google.firebase.vertexai.common.shared.FunctionCall(name, args) - -internal fun FunctionResponse.toInternal() = - com.google.firebase.vertexai.common.shared.FunctionResponse(name, response) - internal fun SafetySetting.toInternal() = com.google.firebase.vertexai.common.shared.SafetySetting( harmCategory.toInternal(), @@ -235,10 +229,16 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part { InlineDataPart(data, inlineData.mimeType) } } - is com.google.firebase.vertexai.common.shared.FunctionCallPart -> - FunctionCallPart(functionCall.toPublic()) - is com.google.firebase.vertexai.common.shared.FunctionResponsePart -> - FunctionResponsePart(functionResponse.toPublic()) + is FunctionCallPart -> + com.google.firebase.vertexai.type.FunctionCallPart( + functionCall.name, + functionCall.args.orEmpty().mapValues { it.value ?: JsonNull } + ) + is FunctionResponsePart -> + com.google.firebase.vertexai.type.FunctionResponsePart( + functionResponse.name, + functionResponse.response, + ) is com.google.firebase.vertexai.common.shared.FileDataPart -> FileDataPart(fileData.mimeType, fileData.fileUri) else -> @@ -248,15 +248,6 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part { } } -internal fun com.google.firebase.vertexai.common.shared.FunctionCall.toPublic() = - FunctionCall(name, args.orEmpty().mapValues { it.value ?: JsonNull }) - -internal fun com.google.firebase.vertexai.common.shared.FunctionResponse.toPublic() = - FunctionResponse( - name, - response, - ) - internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic(): Citation { val publicationDateAsCalendar = publicationDate?.let { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt index 9e133896b89..41ddfcfbe41 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt @@ -19,6 +19,7 @@ package com.google.firebase.vertexai.type import android.graphics.Bitmap import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonObject +import org.json.JSONObject /** Interface representing data sent to and received from requests. */ public interface Part @@ -44,35 +45,21 @@ public class ImagePart(public val image: Bitmap) : Part public class InlineDataPart(public val inlineData: ByteArray, public val mimeType: String) : Part /** - * Represents a function call request from the model - * - * @param functionCall The information provided by the model to call a function. - */ -public class FunctionCallPart(public val functionCall: FunctionCall) : Part - -/** - * The result of calling a function as requested by the model. - * - * @param functionResponse The information to send back to the model as the result of a functions - * call. - */ -public class FunctionResponsePart(public val functionResponse: FunctionResponse) : Part - -/** - * The data necessary to invoke function [name] using the arguments [args]. + * Represents function call name and params received from requests. * * @param name the name of the function to call * @param args the function parameters and values as a [Map] */ -public class FunctionCall(public val name: String, public val args: Map) +public class FunctionCallPart(public val name: String, public val args: Map) : + Part /** - * The [response] generated after calling function [name]. + * Represents function call output to be returned to the model when it requests a function call. * * @param name the name of the called function - * @param response the response produced by the function as a [JsonObject] + * @param response the response produced by the function as a [JSONObject] */ -public class FunctionResponse(public val name: String, public val response: JsonObject) +public class FunctionResponsePart(public val name: String, public val response: JsonObject) : Part /** * Represents file data stored in Cloud Storage for Firebase, referenced by URI. diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt index a71d8e16c08..b19dcd7aa1e 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt @@ -353,7 +353,7 @@ internal class UnarySnapshotTests { val response = model.generateContent("prompt") val callPart = (response.candidates.first().content.parts.first() as FunctionCallPart) - callPart.functionCall.args["season"] shouldBe JsonPrimitive(null) + callPart.args["season"] shouldBe JsonPrimitive(null) } } @@ -370,7 +370,7 @@ internal class UnarySnapshotTests { it.parts.first().shouldBeInstanceOf() } - callPart.functionCall.args["current"] shouldBe JsonPrimitive(true) + callPart.args["current"] shouldBe JsonPrimitive(true) } } @@ -387,11 +387,9 @@ internal class UnarySnapshotTests { it.parts.first().shouldBeInstanceOf() } - callPart.functionCall.args["current"] shouldBe JsonPrimitive(true) - callPart.functionCall.args["testObject"]!! - .jsonObject["testProperty"]!! - .jsonPrimitive - .content shouldBe "string property" + callPart.args["current"] shouldBe JsonPrimitive(true) + callPart.args["testObject"]!!.jsonObject["testProperty"]!!.jsonPrimitive.content shouldBe + "string property" } } @@ -402,8 +400,8 @@ internal class UnarySnapshotTests { val response = model.generateContent("prompt") val callPart = response.functionCalls.shouldNotBeEmpty().first() - callPart.functionCall.name shouldBe "current_time" - callPart.functionCall.args.isEmpty() shouldBe true + callPart.name shouldBe "current_time" + callPart.args.isEmpty() shouldBe true } } @@ -414,9 +412,9 @@ internal class UnarySnapshotTests { val response = model.generateContent("prompt") val callPart = response.functionCalls.shouldNotBeEmpty().first() - callPart.functionCall.name shouldBe "sum" - callPart.functionCall.args["x"] shouldBe JsonPrimitive(4) - callPart.functionCall.args["y"] shouldBe JsonPrimitive(5) + callPart.name shouldBe "sum" + callPart.args["x"] shouldBe JsonPrimitive(4) + callPart.args["y"] shouldBe JsonPrimitive(5) } } @@ -429,8 +427,8 @@ internal class UnarySnapshotTests { callList.size shouldBe 3 callList.forEach { - it.functionCall.name shouldBe "sum" - it.functionCall.args.size shouldBe 2 + it.name shouldBe "sum" + it.args.size shouldBe 2 } } } @@ -444,7 +442,7 @@ internal class UnarySnapshotTests { response.text shouldBe "The sum of [1, 2, 3] is" callList.size shouldBe 2 - callList.forEach { it.functionCall.args.size shouldBe 2 } + callList.forEach { it.args.size shouldBe 2 } } }