Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.buildClassSerialDescriptor
import kotlinx.serialization.descriptors.element
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder

Expand All @@ -36,7 +37,12 @@ import kotlinx.serialization.encoding.Encoder
*/
internal class FirstOrdinalSerializer<T : Enum<T>>(private val enumClass: KClass<T>) :
KSerializer<T> {
override val descriptor: SerialDescriptor = buildClassSerialDescriptor("FirstOrdinalSerializer")
override val descriptor: SerialDescriptor =
buildClassSerialDescriptor("FirstOrdinalSerializer") {
for (enumValue in enumClass.enumValues()) {
element<String>(enumValue.toString())
}
}

override fun deserialize(decoder: Decoder): T {
val name = decoder.decodeString()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.vertexai

import com.google.firebase.vertexai.common.util.descriptorToJson
import com.google.firebase.vertexai.type.Candidate
import com.google.firebase.vertexai.type.CountTokensResponse
import com.google.firebase.vertexai.type.GenerateContentResponse
import com.google.firebase.vertexai.type.ModalityTokenCount
import com.google.firebase.vertexai.type.Schema
import io.kotest.assertions.json.shouldEqualJson
import org.junit.Test

internal class SerializationTests {
@Test
fun `test countTokensResponse serialization as Json`() {
val expectedJsonAsString =
"""
{
"id": "CountTokensResponse",
"type": "object",
"properties": {
"totalTokens": {
"type": "integer"
},
"totalBillableCharacters": {
"type": "integer"
},
"promptTokensDetails": {
"type": "array",
"items": {
"${'$'}ref": "ModalityTokenCount"
}
}
}
}
"""
.trimIndent()
val actualJson = descriptorToJson(CountTokensResponse.Internal.serializer().descriptor)
expectedJsonAsString shouldEqualJson actualJson.toString()
}

@Test
fun `test modalityTokenCount serialization as Json`() {
val expectedJsonAsString =
"""
{
"id": "ModalityTokenCount",
"type": "object",
"properties": {
"modality": {
"type": "string",
"enum": [
"UNSPECIFIED",
"TEXT",
"IMAGE",
"VIDEO",
"AUDIO",
"DOCUMENT"
]
},
"tokenCount": {
"type": "integer"
}
}
}
"""
.trimIndent()
val actualJson = descriptorToJson(ModalityTokenCount.Internal.serializer().descriptor)
expectedJsonAsString shouldEqualJson actualJson.toString()
}

@Test
fun `test GenerateContentResponse serialization as Json`() {
val expectedJsonAsString =
"""
{
"id": "GenerateContentResponse",
"type": "object",
"properties": {
"candidates": {
"type": "array",
"items": {
"${'$'}ref": "Candidate"
}
},
"promptFeedback": {
"${'$'}ref": "PromptFeedback"
},
"usageMetadata": {
"${'$'}ref": "UsageMetadata"
}
}
}
"""
.trimIndent()
val actualJson = descriptorToJson(GenerateContentResponse.Internal.serializer().descriptor)
expectedJsonAsString shouldEqualJson actualJson.toString()
}

@Test
fun `test Candidate serialization as Json`() {
val expectedJsonAsString =
"""
{
"id": "Candidate",
"type": "object",
"properties": {
"content": {
"${'$'}ref": "Content"
},
"finishReason": {
"type": "string",
"enum": [
"UNKNOWN",
"UNSPECIFIED",
"STOP",
"MAX_TOKENS",
"SAFETY",
"RECITATION",
"OTHER",
"BLOCKLIST",
"PROHIBITED_CONTENT",
"SPII",
"MALFORMED_FUNCTION_CALL"
]
},
"safetyRatings": {
"type": "array",
"items": {
"${'$'}ref": "SafetyRating"
}
},
"citationMetadata": {
"${'$'}ref": "CitationMetadata"
},
"groundingMetadata": {
"${'$'}ref": "GroundingMetadata"
}
}
}
"""
.trimIndent()
val actualJson = descriptorToJson(Candidate.Internal.serializer().descriptor)
expectedJsonAsString shouldEqualJson actualJson.toString()
}

@Test
fun `test Schema serialization as Json`() {
// Unlike the actual schema in the background, we don't represent "type" as an enum, but
// rather as a string. This is because we restrict we values can be used using helper methods
// rather than type
val expectedJsonAsString =
"""
{
"id": "Schema",
"type": "object",
"properties": {
"type": {
"type": "string"
},
"format": {
"type": "string"
},
"description": {
"type": "string"
},
"nullable": {
"type": "boolean"
},
"items": {
"${'$'}ref": "Schema"
},
"enum": {
"type": "array",
"items": {
"type": "string"
}
},
"properties": {
"type": "object",
"additionalProperties": {
"${'$'}ref": "Schema"
}
},
"required": {
"type": "array",
"items": {
"type": "string"
}
}
}
}
"""
.trimIndent()
val actualJson = descriptorToJson(Schema.Internal.serializer().descriptor)
expectedJsonAsString shouldEqualJson actualJson.toString()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.vertexai.common.util

import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.descriptors.elementDescriptors
import kotlinx.serialization.descriptors.elementNames
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.JsonObjectBuilder
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.put
import kotlinx.serialization.json.putJsonObject

/**
* Returns a [JsonObject] representing the classes in the hierarchy of a serialization [descriptor].
*
* The format of the JSON object is similar to that of a Discovery Document, but restricted to these
* fields:
* - id
* - type
* - properties
* - items
* - $ref
*
* @param descriptor The [SerialDescriptor] to process.
*/
@OptIn(ExperimentalSerializationApi::class)
internal fun descriptorToJson(descriptor: SerialDescriptor): JsonObject {
return buildJsonObject {
put("id", simpleNameFromSerialName(descriptor.serialName))
put("type", typeNameFromKind(descriptor.kind))
if (descriptor.kind != StructureKind.CLASS) {
throw UnsupportedOperationException("Only classes can be serialized to JSON for now.")
}
// For top-level enums, add them directly.
if (descriptor.serialName == "FirstOrdinalSerializer") {
addEnumDescription(descriptor)
} else {
addObjectProperties(descriptor)
}
}
}

@OptIn(ExperimentalSerializationApi::class)
internal fun JsonObjectBuilder.addListDescription(descriptor: SerialDescriptor) =
putJsonObject("items") {
val itemDescriptor = descriptor.elementDescriptors.first()
val nestedIsPrimitive = (descriptor.elementsCount == 1 && itemDescriptor.kind is PrimitiveKind)
if (nestedIsPrimitive) {
put("type", typeNameFromKind(itemDescriptor.kind))
} else {
put("\$ref", simpleNameFromSerialName(itemDescriptor.serialName))
}
}

@OptIn(ExperimentalSerializationApi::class)
internal fun JsonObjectBuilder.addEnumDescription(descriptor: SerialDescriptor): JsonElement? {
put("type", typeNameFromKind(SerialKind.ENUM))
return put("enum", JsonArray(descriptor.elementNames.map { JsonPrimitive(it) }))
}

@OptIn(ExperimentalSerializationApi::class)
internal fun JsonObjectBuilder.addObjectProperties(descriptor: SerialDescriptor): JsonElement? {
return putJsonObject("properties") {
for (i in 0 until descriptor.elementsCount) {
val elementDescriptor = descriptor.getElementDescriptor(i)
val elementName = descriptor.getElementName(i)
putJsonObject(elementName) {
when (elementDescriptor.kind) {
StructureKind.LIST -> {
put("type", typeNameFromKind(elementDescriptor.kind))
addListDescription(elementDescriptor)
}
StructureKind.CLASS -> {
if (elementDescriptor.serialName.startsWith("FirstOrdinalSerializer")) {
addEnumDescription(elementDescriptor)
} else {
put("\$ref", simpleNameFromSerialName(elementDescriptor.serialName))
}
}
StructureKind.MAP -> {
put("type", typeNameFromKind(elementDescriptor.kind))
putJsonObject("additionalProperties") {
put(
"\$ref",
simpleNameFromSerialName(elementDescriptor.getElementDescriptor(1).serialName)
)
}
}
else -> {
put("type", typeNameFromKind(elementDescriptor.kind))
}
}
}
}
}
}

@OptIn(ExperimentalSerializationApi::class)
internal fun typeNameFromKind(kind: SerialKind): String {
return when (kind) {
PrimitiveKind.BOOLEAN -> "boolean"
PrimitiveKind.BYTE -> "integer"
PrimitiveKind.CHAR -> "string"
PrimitiveKind.DOUBLE -> "number"
PrimitiveKind.FLOAT -> "number"
PrimitiveKind.INT -> "integer"
PrimitiveKind.LONG -> "integer"
PrimitiveKind.SHORT -> "integer"
PrimitiveKind.STRING -> "string"
StructureKind.CLASS -> "object"
StructureKind.LIST -> "array"
SerialKind.ENUM -> "string"
StructureKind.MAP -> "object"
else -> TODO()
}
}

internal fun simpleNameFromSerialName(serialName: String): String =
serialName
.split(".")
.let {
if (it.last().startsWith("Internal")) {
it[it.size - 2]
} else {
it.last()
}
}
.replace("?", "")
Loading