diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/serialization.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/serialization.kt index e64bb89afc3..4a2570b82d6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/serialization.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/util/serialization.kt @@ -23,6 +23,7 @@ import kotlinx.serialization.KSerializer import kotlinx.serialization.SerialName import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.descriptors.buildClassSerialDescriptor +import kotlinx.serialization.descriptors.element import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.encoding.Encoder @@ -36,7 +37,12 @@ import kotlinx.serialization.encoding.Encoder */ internal class FirstOrdinalSerializer>(private val enumClass: KClass) : KSerializer { - override val descriptor: SerialDescriptor = buildClassSerialDescriptor("FirstOrdinalSerializer") + override val descriptor: SerialDescriptor = + buildClassSerialDescriptor("FirstOrdinalSerializer") { + for (enumValue in enumClass.enumValues()) { + element(enumValue.toString()) + } + } override fun deserialize(decoder: Decoder): T { val name = decoder.decodeString() diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SerializationTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SerializationTests.kt new file mode 100644 index 00000000000..cf6a40680e5 --- /dev/null +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SerializationTests.kt @@ -0,0 +1,215 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai + +import com.google.firebase.vertexai.common.util.descriptorToJson +import com.google.firebase.vertexai.type.Candidate +import com.google.firebase.vertexai.type.CountTokensResponse +import com.google.firebase.vertexai.type.GenerateContentResponse +import com.google.firebase.vertexai.type.ModalityTokenCount +import com.google.firebase.vertexai.type.Schema +import io.kotest.assertions.json.shouldEqualJson +import org.junit.Test + +internal class SerializationTests { + @Test + fun `test countTokensResponse serialization as Json`() { + val expectedJsonAsString = + """ + { + "id": "CountTokensResponse", + "type": "object", + "properties": { + "totalTokens": { + "type": "integer" + }, + "totalBillableCharacters": { + "type": "integer" + }, + "promptTokensDetails": { + "type": "array", + "items": { + "${'$'}ref": "ModalityTokenCount" + } + } + } + } + """ + .trimIndent() + val actualJson = descriptorToJson(CountTokensResponse.Internal.serializer().descriptor) + expectedJsonAsString shouldEqualJson actualJson.toString() + } + + @Test + fun `test modalityTokenCount serialization as Json`() { + val expectedJsonAsString = + """ + { + "id": "ModalityTokenCount", + "type": "object", + "properties": { + "modality": { + "type": "string", + "enum": [ + "UNSPECIFIED", + "TEXT", + "IMAGE", + "VIDEO", + "AUDIO", + "DOCUMENT" + ] + }, + "tokenCount": { + "type": "integer" + } + } + } + """ + .trimIndent() + val actualJson = descriptorToJson(ModalityTokenCount.Internal.serializer().descriptor) + expectedJsonAsString shouldEqualJson actualJson.toString() + } + + @Test + fun `test GenerateContentResponse serialization as Json`() { + val expectedJsonAsString = + """ + { + "id": "GenerateContentResponse", + "type": "object", + "properties": { + "candidates": { + "type": "array", + "items": { + "${'$'}ref": "Candidate" + } + }, + "promptFeedback": { + "${'$'}ref": "PromptFeedback" + }, + "usageMetadata": { + "${'$'}ref": "UsageMetadata" + } + } + } + """ + .trimIndent() + val actualJson = descriptorToJson(GenerateContentResponse.Internal.serializer().descriptor) + expectedJsonAsString shouldEqualJson actualJson.toString() + } + + @Test + fun `test Candidate serialization as Json`() { + val expectedJsonAsString = + """ + { + "id": "Candidate", + "type": "object", + "properties": { + "content": { + "${'$'}ref": "Content" + }, + "finishReason": { + "type": "string", + "enum": [ + "UNKNOWN", + "UNSPECIFIED", + "STOP", + "MAX_TOKENS", + "SAFETY", + "RECITATION", + "OTHER", + "BLOCKLIST", + "PROHIBITED_CONTENT", + "SPII", + "MALFORMED_FUNCTION_CALL" + ] + }, + "safetyRatings": { + "type": "array", + "items": { + "${'$'}ref": "SafetyRating" + } + }, + "citationMetadata": { + "${'$'}ref": "CitationMetadata" + }, + "groundingMetadata": { + "${'$'}ref": "GroundingMetadata" + } + } + } + """ + .trimIndent() + val actualJson = descriptorToJson(Candidate.Internal.serializer().descriptor) + expectedJsonAsString shouldEqualJson actualJson.toString() + } + + @Test + fun `test Schema serialization as Json`() { + /** + * Unlike the actual schema in the background, we don't represent "type" as an enum, but rather + * as a string. This is because we restrict what values can be used (using helper methods, + * rather than type). + */ + val expectedJsonAsString = + """ + { + "id": "Schema", + "type": "object", + "properties": { + "type": { + "type": "string" + }, + "format": { + "type": "string" + }, + "description": { + "type": "string" + }, + "nullable": { + "type": "boolean" + }, + "items": { + "${'$'}ref": "Schema" + }, + "enum": { + "type": "array", + "items": { + "type": "string" + } + }, + "properties": { + "type": "object", + "additionalProperties": { + "${'$'}ref": "Schema" + } + }, + "required": { + "type": "array", + "items": { + "type": "string" + } + } + } + } + """ + .trimIndent() + val actualJson = descriptorToJson(Schema.Internal.serializer().descriptor) + expectedJsonAsString shouldEqualJson actualJson.toString() + } +} diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/descriptorToJson.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/descriptorToJson.kt new file mode 100644 index 00000000000..31d9156bc75 --- /dev/null +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/descriptorToJson.kt @@ -0,0 +1,167 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai.common.util + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.SerialKind +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.descriptors.elementDescriptors +import kotlinx.serialization.descriptors.elementNames +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonObjectBuilder +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlinx.serialization.json.putJsonObject + +/** + * Returns a [JsonObject] representing the classes in the hierarchy of a serialization [descriptor]. + * + * The format of the JSON object is similar to that of a Discovery Document, but restricted to these + * fields: + * - id + * - type + * - properties + * - items + * - $ref + * + * @param descriptor The [SerialDescriptor] to process. + */ +@OptIn(ExperimentalSerializationApi::class) +internal fun descriptorToJson(descriptor: SerialDescriptor): JsonObject { + return buildJsonObject { + put("id", simpleNameFromSerialName(descriptor.serialName)) + put("type", typeNameFromKind(descriptor.kind)) + if (descriptor.kind != StructureKind.CLASS) { + throw UnsupportedOperationException("Only classes can be serialized to JSON for now.") + } + // For top-level enums, add them directly. + if (descriptor.serialName == "FirstOrdinalSerializer") { + addEnumDescription(descriptor) + } else { + addObjectProperties(descriptor) + } + } +} + +@OptIn(ExperimentalSerializationApi::class) +internal fun JsonObjectBuilder.addListDescription(descriptor: SerialDescriptor) = + putJsonObject("items") { + val itemDescriptor = descriptor.elementDescriptors.first() + val nestedIsPrimitive = (descriptor.elementsCount == 1 && itemDescriptor.kind is PrimitiveKind) + if (nestedIsPrimitive) { + put("type", typeNameFromKind(itemDescriptor.kind)) + } else { + put("\$ref", simpleNameFromSerialName(itemDescriptor.serialName)) + } + } + +@OptIn(ExperimentalSerializationApi::class) +internal fun JsonObjectBuilder.addEnumDescription(descriptor: SerialDescriptor): JsonElement? { + put("type", typeNameFromKind(SerialKind.ENUM)) + return put("enum", JsonArray(descriptor.elementNames.map { JsonPrimitive(it) })) +} + +@OptIn(ExperimentalSerializationApi::class) +internal fun JsonObjectBuilder.addObjectProperties(descriptor: SerialDescriptor): JsonElement? { + return putJsonObject("properties") { + for (i in 0 until descriptor.elementsCount) { + val elementDescriptor = descriptor.getElementDescriptor(i) + val elementName = descriptor.getElementName(i) + putJsonObject(elementName) { + when (elementDescriptor.kind) { + StructureKind.LIST -> { + put("type", typeNameFromKind(elementDescriptor.kind)) + addListDescription(elementDescriptor) + } + StructureKind.CLASS -> { + if (elementDescriptor.serialName.startsWith("FirstOrdinalSerializer")) { + addEnumDescription(elementDescriptor) + } else { + put("\$ref", simpleNameFromSerialName(elementDescriptor.serialName)) + } + } + StructureKind.MAP -> { + put("type", typeNameFromKind(elementDescriptor.kind)) + putJsonObject("additionalProperties") { + put( + "\$ref", + simpleNameFromSerialName(elementDescriptor.getElementDescriptor(1).serialName) + ) + } + } + else -> { + put("type", typeNameFromKind(elementDescriptor.kind)) + } + } + } + } + } +} + +@OptIn(ExperimentalSerializationApi::class) +internal fun typeNameFromKind(kind: SerialKind): String { + return when (kind) { + PrimitiveKind.BOOLEAN -> "boolean" + PrimitiveKind.BYTE -> "integer" + PrimitiveKind.CHAR -> "string" + PrimitiveKind.DOUBLE -> "number" + PrimitiveKind.FLOAT -> "number" + PrimitiveKind.INT -> "integer" + PrimitiveKind.LONG -> "integer" + PrimitiveKind.SHORT -> "integer" + PrimitiveKind.STRING -> "string" + StructureKind.CLASS -> "object" + StructureKind.LIST -> "array" + SerialKind.ENUM -> "string" + StructureKind.MAP -> "object" + /* Only add new cases if they show up in actual test scenarios. */ + else -> TODO() + } +} + +/** + * Extracts the name expected for a class from its serial name. + * + * Our serialization classes are nested within the public-facing classes, and that's the name we + * want in the json output. There are two class names + * + * - `com.google.firebase.vertexai.type.Content.Internal` for regular scenarios + * - `com.google.firebase.vertexai.type.Content.Internal.SomeClass` for nested classes in the + * serializer. + * + * For the later time we need the second to last component, for the former we need the last + * component. + * + * Additionally, given that types can be nullable, we need to strip the `?` from the end of the + * name. + */ +internal fun simpleNameFromSerialName(serialName: String): String = + serialName + .split(".") + .let { + if (it.last().startsWith("Internal")) { + it[it.size - 2] + } else { + it.last() + } + } + .replace("?", "")