Skip to content

Automatic Schema Generation Test #7217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firebase-ai/firebase-ai.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ dependencies {
implementation(libs.kotlinx.coroutines.android)
implementation(libs.kotlinx.coroutines.reactive)
implementation(libs.reactive.streams)
implementation(libs.kotlin.reflect)
implementation("com.google.guava:listenablefuture:1.0")
implementation("androidx.concurrent:concurrent-futures:1.2.0")
implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.google.firebase.ai.annotation

import kotlin.reflect.KClass

@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.CLASS, AnnotationTarget.PROPERTY)
public annotation class SchemaDetails(val description: String, val title: String)

@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.PROPERTY)
public annotation class NumSchemaDetails(val minimum: Double, val maximum: Double)

@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.PROPERTY)
public annotation class ListSchemaDetails(val minItems: Int, val maxItems: Int, val clazz: KClass<*>)

@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.PROPERTY)
public annotation class StringSchemaDetails(val format: String)
142 changes: 142 additions & 0 deletions firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,17 @@

package com.google.firebase.ai.type

import com.google.firebase.ai.annotation.ListSchemaDetails
import com.google.firebase.ai.annotation.NumSchemaDetails
import com.google.firebase.ai.annotation.SchemaDetails
import com.google.firebase.ai.annotation.StringSchemaDetails
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlin.reflect.KClass
import kotlin.reflect.KProperty1
import kotlin.reflect.full.findAnnotations
import kotlin.reflect.full.memberProperties
import kotlin.reflect.jvm.jvmErasure

public abstract class StringFormat private constructor(internal val value: String) {
public class Custom(value: String) : StringFormat(value)
Expand Down Expand Up @@ -320,8 +330,140 @@ internal constructor(
*/
@JvmStatic
public fun anyOf(schemas: List<Schema>): Schema = Schema(type = "ANYOF", anyOf = schemas)

@JvmStatic
public fun fromClass(clazz: KClass<*>, nullable: Boolean = false): Schema {
return fromClassHelper(clazz, nullable)
}

@JvmStatic
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The @JvmStatic annotation on a private function has no effect and can be misleading. It should be removed for code clarity.

private fun fromClassHelper(
clazz: KClass<*>,
nullable: Boolean = false,
propertyName: String? = null,
schemaDetails: SchemaDetails? = null,
numSchemaDetails: NumSchemaDetails? = null,
listSchemaDetails: ListSchemaDetails? = null,
stringSchemaDetails: StringSchemaDetails? = null
): Schema {
return when (clazz) {
Int::class -> {
integer(
schemaDetails?.description,
nullable,
schemaDetails?.title,
numSchemaDetails?.minimum,
numSchemaDetails?.maximum
)
Comment on lines +350 to +357
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The enclosing braces {} are redundant here and can be removed for conciseness. This also applies to the other primitive type branches (Long, Boolean, Float, Double) in this when statement.

integer(
            schemaDetails?.description,
            nullable,
            schemaDetails?.title,
            numSchemaDetails?.minimum,
            numSchemaDetails?.maximum
          )

}
Long::class -> {
long(
schemaDetails?.description,
nullable,
schemaDetails?.title,
numSchemaDetails?.minimum,
numSchemaDetails?.maximum
)
}
Boolean::class -> {
boolean(
schemaDetails?.description,
nullable,
schemaDetails?.title
)
}
Float::class -> {
float(
schemaDetails?.description,
nullable,
schemaDetails?.title,
numSchemaDetails?.minimum,
numSchemaDetails?.maximum
)
}
Double::class -> {
double(
schemaDetails?.description,
nullable,
schemaDetails?.title,
numSchemaDetails?.minimum,
numSchemaDetails?.maximum
)
}
String::class -> {
string(
schemaDetails?.description,
nullable,
stringSchemaDetails?.format?.let { StringFormat.Custom(it) },
schemaDetails?.title
)
}

List::class -> {
if (listSchemaDetails == null) {
throw IllegalStateException(
"${clazz.simpleName}$${propertyName} must include " +
"@ListSchemaDetails to use automatic schema generation."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The error message when @ListSchemaDetails is missing can be confusing. The use of $ to join the class and property name is unconventional, and propertyName can be null, resulting in a message like List$null.... A clearer error message would improve usability.

val subject = propertyName?.let { "List property '$it'" } ?: "A top-level List"
throw IllegalStateException(
  "$subject must include @ListSchemaDetails to use automatic schema generation."
)

}
array(
fromClassHelper(listSchemaDetails.clazz),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There are a couple of issues with List property handling:

  1. Annotations not applied to items: When generating a schema for a List property, annotations like @NumSchemaDetails or @StringSchemaDetails on that property are ignored. They should be applied to the schema of the list's items. The recursive call to fromClassHelper for the list item type should pass along these annotations.

  2. Item nullability not supported: The nullability of list items (e.g., List<String?>) is not supported. The clazz property in @ListSchemaDetails erases this information. A more robust solution would be to remove the clazz parameter from @ListSchemaDetails and instead use reflection to inspect the property's type arguments to determine the item's type and nullability. While this would be a breaking change to this new API, it's a significant limitation worth addressing.

Here's a suggested fix for passing down the annotations:

fromClassHelper(clazz = listSchemaDetails.clazz, numSchemaDetails = numSchemaDetails, stringSchemaDetails = stringSchemaDetails)

schemaDetails?.description,
nullable,
schemaDetails?.title,
listSchemaDetails.minItems,
listSchemaDetails.maxItems
)
}
else -> {
val isSerializable = clazz.findAnnotations(Serializable::class).isNotEmpty()
if (!isSerializable) {
throw IllegalStateException("${clazz.simpleName} must be @Serializable to use automatic " +
"schema generation.")
}
if (!clazz.isData) {
throw IllegalStateException("${clazz.simpleName} must be a data class to use automatic " +
"schema generation.")
}
val classSchemaDetails = schemaDetails ?: clazz.findAnnotations(SchemaDetails::class).firstOrNull()
?: throw IllegalStateException("${clazz.simpleName} must include @SchemaDetails to use " +
"automatic schema generation.")
val properties = clazz.memberProperties.associate { property: KProperty1<out Any, *> ->
val propertyDetails = property.findAnnotations(SchemaDetails::class).firstOrNull()
val stringDetails =
property.findAnnotations(StringSchemaDetails::class).firstOrNull()
val numDetails = property.findAnnotations(NumSchemaDetails::class).firstOrNull()
val listDetails = property.findAnnotations(ListSchemaDetails::class).firstOrNull()
val serialName = property.findAnnotations(SerialName::class).firstOrNull()
val deepPropertyName = serialName?.value ?: property.name
val propertyClass = property.returnType
Pair(
deepPropertyName,
fromClassHelper(
propertyClass.jvmErasure,
propertyClass.isMarkedNullable,
deepPropertyName,
propertyDetails,
numDetails,
listDetails,
stringDetails
)
)
}
obj(
properties,
emptyList(),
classSchemaDetails.description,
nullable,
classSchemaDetails.title
)
}
}
}
}



internal fun toInternal(): Internal {
val cleanedType =
if (type == "ANYOF") {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package com.google.firebase.ai

import com.google.firebase.ai.annotation.ListSchemaDetails
import com.google.firebase.ai.annotation.NumSchemaDetails
import com.google.firebase.ai.annotation.SchemaDetails
import com.google.firebase.ai.annotation.StringSchemaDetails
import com.google.firebase.ai.type.Schema
import com.google.firebase.ai.type.StringFormat
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import org.junit.Test

class SchemaGenerationTest {

@Test
fun testSchemaGeneration() {
val generatedSchema = Schema.fromClass(TestClass1::class)
val schema =
Schema.obj(
description = "A test class (1)",
title = "TestClass1",
properties =
mapOf(
"val1" to Schema.integer("A test field (1)", false, "var1"),
"val2" to Schema.long("A test field (2)", false, "var2", 20.0, 30.0),
"val3" to Schema.boolean("A test field (3)", false, "var3"),
"val4" to Schema.float("A test field (4)", false, "var4"),
"val5" to Schema.double("A test field (5)", false, "var5"),
"val6" to
Schema.string("A test field (6)", false, StringFormat.Custom("StringFormat"), "var6"),
"val7" to
Schema.array(
Schema.obj(
mapOf("val1" to Schema.integer(nullable = true)),
emptyList(),
"A test class (2)",
false,
"TestClass2",
),
"A test field (7)",
false,
"var7",
0,
500,
),
"val8" to
Schema.obj(
mapOf("customSerialName" to Schema.array(Schema.string(), minItems = 0, maxItems = 500)),
emptyList(),
"A test field (8)",
false,
"var8",
),
),
)
assert(schema.toInternal() == generatedSchema.toInternal())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's better to use an assertion library like Kotest (which is already a dependency in this project) instead of the standard assert function. This provides more informative failure messages on test failure. You'll need to add import io.kotest.matchers.shouldBe to the file.

Suggested change
assert(schema.toInternal() == generatedSchema.toInternal())
generatedSchema.toInternal() shouldBe schema.toInternal()

}

@Serializable
@SchemaDetails("A test class (1)", "TestClass1")
data class TestClass1(
@SchemaDetails("A test field (1)", "var1")
val val1: Int,
@NumSchemaDetails(minimum = 20.0, maximum = 30.0)
@SchemaDetails("A test field (2)", "var2")
val val2: Long,
@SchemaDetails("A test field (3)", "var3")
val val3: Boolean,
@SchemaDetails("A test field (4)", "var4")
val val4: Float,
@SchemaDetails("A test field (5)", "var5")
val val5: Double,
@SchemaDetails("A test field (6)", "var6")
@StringSchemaDetails("StringFormat")
val val6: String,
@SchemaDetails("A test field (7)", "var7")
@ListSchemaDetails(0, 500, TestClass2::class)
val val7: List<TestClass2>,
@SchemaDetails("A test field (8)", "var8")
val val8: TestClass3,
)

@Serializable
@SchemaDetails("A test class (2)", "TestClass2")
data class TestClass2(val val1: Int?)

@Serializable
@SchemaDetails("A test class (3)", "TestClass3")
data class TestClass3(
@ListSchemaDetails(0, 500, String::class)
@SerialName("customSerialName")
val val1: List<String>
)
}
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ kotlin-bom = { module = "org.jetbrains.kotlin:kotlin-bom", version.ref = "kotlin
kotlin-coroutines-tasks = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-play-services", version.ref = "coroutines" }
kotlin-stdlib = { module = "org.jetbrains.kotlin:kotlin-stdlib", version.ref = "kotlin" }
kotlin-stdlib-jdk8 = { module = "org.jetbrains.kotlin:kotlin-stdlib-jdk8", version.ref = "kotlin" }
kotlin-reflect = { module = "org.jetbrains.kotlin:kotlin-reflect", version.ref = "kotlin" }
kotlinx-coroutines-android = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-android", version.ref = "coroutines" }
kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "coroutines" }
kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref = "coroutines" }
Expand Down
Loading