Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ import android.graphics.BitmapFactory
import android.util.Base64
import com.google.firebase.vertexai.common.client.Schema
import com.google.firebase.vertexai.common.shared.FileData
import com.google.firebase.vertexai.common.shared.FunctionCall
import com.google.firebase.vertexai.common.shared.FunctionCallPart
import com.google.firebase.vertexai.common.shared.FunctionResponse
import com.google.firebase.vertexai.common.shared.FunctionResponsePart
import com.google.firebase.vertexai.common.shared.InlineData
import com.google.firebase.vertexai.type.BlockReason
import com.google.firebase.vertexai.type.Candidate
Expand All @@ -34,8 +30,12 @@ import com.google.firebase.vertexai.type.Content
import com.google.firebase.vertexai.type.CountTokensResponse
import com.google.firebase.vertexai.type.FileDataPart
import com.google.firebase.vertexai.type.FinishReason
import com.google.firebase.vertexai.type.FunctionCall
import com.google.firebase.vertexai.type.FunctionCallPart
import com.google.firebase.vertexai.type.FunctionCallingConfig
import com.google.firebase.vertexai.type.FunctionDeclaration
import com.google.firebase.vertexai.type.FunctionResponse
import com.google.firebase.vertexai.type.FunctionResponsePart
import com.google.firebase.vertexai.type.GenerateContentResponse
import com.google.firebase.vertexai.type.GenerationConfig
import com.google.firebase.vertexai.type.HarmBlockMethod
Expand All @@ -59,6 +59,7 @@ import java.io.ByteArrayOutputStream
import java.util.Calendar
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonPrimitive
import org.json.JSONObject

private const val BASE_64_FLAGS = Base64.NO_WRAP
Expand All @@ -80,10 +81,10 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part
com.google.firebase.vertexai.common.shared.InlineDataPart(
InlineData(mimeType, Base64.encodeToString(inlineData, BASE_64_FLAGS))
)
is com.google.firebase.vertexai.type.FunctionCallPart ->
FunctionCallPart(FunctionCall(name, args.orEmpty()))
is com.google.firebase.vertexai.type.FunctionResponsePart ->
FunctionResponsePart(FunctionResponse(name, response.toInternal()))
is FunctionCallPart ->
com.google.firebase.vertexai.common.shared.FunctionCallPart(functionCall.toInternal())
is FunctionResponsePart ->
com.google.firebase.vertexai.common.shared.FunctionResponsePart(functionResponse.toInternal())
is FileDataPart ->
com.google.firebase.vertexai.common.shared.FileDataPart(
FileData(mimeType = mimeType, fileUri = uri)
Expand All @@ -95,6 +96,15 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part
}
}

internal fun FunctionCall.toInternal() =
com.google.firebase.vertexai.common.shared.FunctionCall(
name,
args.orEmpty().mapValues { it.value.toString() }
)

internal fun FunctionResponse.toInternal() =
com.google.firebase.vertexai.common.shared.FunctionResponse(name, response)

internal fun SafetySetting.toInternal() =
com.google.firebase.vertexai.common.shared.SafetySetting(
harmCategory.toInternal(),
Expand Down Expand Up @@ -213,16 +223,10 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part {
InlineDataPart(inlineData.mimeType, data)
}
}
is FunctionCallPart ->
com.google.firebase.vertexai.type.FunctionCallPart(
functionCall.name,
functionCall.args.orEmpty(),
)
is FunctionResponsePart ->
com.google.firebase.vertexai.type.FunctionResponsePart(
functionResponse.name,
functionResponse.response.toPublic(),
)
is com.google.firebase.vertexai.common.shared.FunctionCallPart ->
FunctionCallPart(functionCall.toPublic())
is com.google.firebase.vertexai.common.shared.FunctionResponsePart ->
FunctionResponsePart(functionResponse.toPublic())
is com.google.firebase.vertexai.common.shared.FileDataPart ->
FileDataPart(fileData.mimeType, fileData.fileUri)
else ->
Expand All @@ -232,6 +236,21 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part {
}
}

internal fun com.google.firebase.vertexai.common.shared.FunctionCall.toPublic() =
FunctionCall(
name,
args.orEmpty().mapValues {
val argValue = it.value
if (argValue == null) JsonPrimitive(null) else Json.parseToJsonElement(argValue)
}
)

internal fun com.google.firebase.vertexai.common.shared.FunctionResponse.toPublic() =
FunctionResponse(
name,
response,
)

internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic(): Citation {
val publicationDateAsCalendar =
publicationDate?.let {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
package com.google.firebase.vertexai.type

import android.graphics.Bitmap
import org.json.JSONObject
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject

/** Interface representing data sent to and received from requests. */
interface Part
Expand All @@ -44,20 +45,35 @@ class ImagePart(val image: Bitmap) : Part
class InlineDataPart(val mimeType: String, val inlineData: ByteArray) : Part

/**
* Represents function call name and params received from requests.
* Represents a function call request from the model
*
* @param functionCall The information provided by the model to call a function.
*/
class FunctionCallPart(val functionCall: FunctionCall) : Part

/**
* The result of calling a function as requested by the model.
*
* @param functionResponse The information to send back to the model as the result of a functions
* call.
*/
class FunctionResponsePart(val functionResponse: FunctionResponse) : Part

/**
* The data necessary to invoke function [name] using the arguments [args].
*
* @param name the name of the function to call
* @param args the function parameters and values as a [Map]
*/
class FunctionCallPart(val name: String, val args: Map<String, String?>) : Part
class FunctionCall(val name: String, val args: Map<String, JsonElement>)

/**
* Represents function call output to be returned to the model when it requests a function call.
* The [response] generated after calling function [name].
*
* @param name the name of the called function
* @param response the response produced by the function as a [JSONObject]
* @param response the response produced by the function as a [JsonObject]
*/
class FunctionResponsePart(val name: String, val response: JSONObject) : Part
class FunctionResponse(val name: String, val response: JsonObject)

/**
* Represents file data stored in Cloud Storage for Firebase, referenced by URI.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import io.kotest.matchers.types.shouldBeInstanceOf
import io.ktor.http.HttpStatusCode
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.json.JsonPrimitive
import org.json.JSONArray
import org.junit.Test

Expand Down Expand Up @@ -350,7 +351,7 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")
val callPart = (response.candidates.first().content.parts.first() as FunctionCallPart)

callPart.args["season"] shouldBe null
callPart.functionCall.args["season"] shouldBe JsonPrimitive(null)
}
}

Expand All @@ -367,7 +368,7 @@ internal class UnarySnapshotTests {
it.parts.first().shouldBeInstanceOf<FunctionCallPart>()
}

callPart.args["current"] shouldBe "true"
callPart.functionCall.args["current"] shouldBe JsonPrimitive(true)
}
}

Expand All @@ -378,8 +379,8 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")
val callPart = response.functionCalls.shouldNotBeEmpty().first()

callPart.name shouldBe "current_time"
callPart.args.isEmpty() shouldBe true
callPart.functionCall.name shouldBe "current_time"
callPart.functionCall.args.isEmpty() shouldBe true
}
}

Expand All @@ -390,9 +391,9 @@ internal class UnarySnapshotTests {
val response = model.generateContent("prompt")
val callPart = response.functionCalls.shouldNotBeEmpty().first()

callPart.name shouldBe "sum"
callPart.args["x"] shouldBe "4"
callPart.args["y"] shouldBe "5"
callPart.functionCall.name shouldBe "sum"
callPart.functionCall.args["x"] shouldBe JsonPrimitive(4)
callPart.functionCall.args["y"] shouldBe JsonPrimitive(5)
}
}

Expand All @@ -405,8 +406,8 @@ internal class UnarySnapshotTests {

callList.size shouldBe 3
callList.forEach {
it.name shouldBe "sum"
it.args.size shouldBe 2
it.functionCall.name shouldBe "sum"
it.functionCall.args.size shouldBe 2
}
}
}
Expand All @@ -420,7 +421,7 @@ internal class UnarySnapshotTests {

response.text shouldBe "The sum of [1, 2, 3] is"
callList.size shouldBe 2
callList.forEach { it.args.size shouldBe 2 }
callList.forEach { it.functionCall.args.size shouldBe 2 }
}
}

Expand Down
Loading