diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt index f6c1bc22b88..c34af862def 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt @@ -62,7 +62,7 @@ internal data class FunctionResponsePart(val functionResponse: FunctionResponse) @Serializable internal data class FunctionResponse(val name: String, val response: JsonObject) @Serializable -internal data class FunctionCall(val name: String, val args: Map? = null) +internal data class FunctionCall(val name: String, val args: Map? = null) @Serializable internal data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt index 5a86ee8de4a..ed846375d23 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt @@ -58,8 +58,8 @@ import com.google.firebase.vertexai.type.content import java.io.ByteArrayOutputStream import java.util.Calendar import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonNull import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.JsonPrimitive import org.json.JSONObject private const val BASE_64_FLAGS = Base64.NO_WRAP @@ -97,10 +97,7 @@ internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part } internal fun FunctionCall.toInternal() = - com.google.firebase.vertexai.common.shared.FunctionCall( - name, - args.orEmpty().mapValues { it.value.toString() } - ) + com.google.firebase.vertexai.common.shared.FunctionCall(name, args) internal fun FunctionResponse.toInternal() = com.google.firebase.vertexai.common.shared.FunctionResponse(name, response) @@ -237,13 +234,7 @@ internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part { } internal fun com.google.firebase.vertexai.common.shared.FunctionCall.toPublic() = - FunctionCall( - name, - args.orEmpty().mapValues { - val argValue = it.value - if (argValue == null) JsonPrimitive(null) else Json.parseToJsonElement(argValue) - } - ) + FunctionCall(name, args.orEmpty().mapValues { it.value ?: JsonNull }) internal fun com.google.firebase.vertexai.common.shared.FunctionResponse.toPublic() = FunctionResponse( diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt index 50113660208..a71d8e16c08 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/UnarySnapshotTests.kt @@ -45,6 +45,8 @@ import io.ktor.http.HttpStatusCode import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.withTimeout import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive import org.json.JSONArray import org.junit.Test @@ -372,6 +374,27 @@ internal class UnarySnapshotTests { } } + @Test + fun `function call with complex json literal parses correctly`() = + goldenUnaryFile("unary-success-function-call-complex-json-literal.json") { + withTimeout(testTimeout) { + val response = model.generateContent("prompt") + val content = response.candidates.shouldNotBeNullOrEmpty().first().content + val callPart = + content.let { + it.shouldNotBeNull() + it.parts.shouldNotBeEmpty() + it.parts.first().shouldBeInstanceOf() + } + + callPart.functionCall.args["current"] shouldBe JsonPrimitive(true) + callPart.functionCall.args["testObject"]!! + .jsonObject["testProperty"]!! + .jsonPrimitive + .content shouldBe "string property" + } + } + @Test fun `function call contains no arguments`() = goldenUnaryFile("unary-success-function-call-no-arguments.json") { diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt index b08eb104248..6749a31e952 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt @@ -36,6 +36,7 @@ import io.ktor.http.HttpStatusCode import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.withTimeout import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonPrimitive import org.junit.Test @Serializable internal data class MountainColors(val name: String, val colors: List) @@ -330,7 +331,7 @@ internal class UnarySnapshotTests { } callPart.functionCall.args shouldNotBe null - callPart.functionCall.args?.get("current") shouldBe "true" + callPart.functionCall.args?.get("current") shouldBe JsonPrimitive(true) } }