Skip to content

Automatic Schema Generation Test #7217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions firebase-ai/api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,38 @@ package com.google.firebase.ai {

}

package com.google.firebase.ai.annotation {

@kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.RUNTIME) @kotlin.annotation.Target(allowedTargets=kotlin.annotation.AnnotationTarget.PROPERTY) public @interface ListSchemaDetails {
method public abstract kotlin.reflect.KClass<? extends java.lang.Object?> clazz();
method public abstract int maxItems();
method public abstract int minItems();
property public abstract kotlin.reflect.KClass<? extends java.lang.Object?> clazz;
property public abstract int maxItems;
property public abstract int minItems;
}

@kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.RUNTIME) @kotlin.annotation.Target(allowedTargets=kotlin.annotation.AnnotationTarget.PROPERTY) public @interface NumSchemaDetails {
method public abstract double maximum();
method public abstract double minimum();
property public abstract double maximum;
property public abstract double minimum;
}

@kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.RUNTIME) @kotlin.annotation.Target(allowedTargets={kotlin.annotation.AnnotationTarget.CLASS, kotlin.annotation.AnnotationTarget.PROPERTY}) public @interface SchemaDetails {
method public abstract String description();
method public abstract String title();
property public abstract String description;
property public abstract String title;
}

@kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.RUNTIME) @kotlin.annotation.Target(allowedTargets=kotlin.annotation.AnnotationTarget.PROPERTY) public @interface StringSchemaDetails {
method public abstract String format();
property public abstract String format;
}

}

package com.google.firebase.ai.java {

public abstract class ChatFutures {
Expand Down Expand Up @@ -955,6 +987,7 @@ package com.google.firebase.ai.type {
method public static com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values, String? description = null);
method public static com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values, String? description = null, boolean nullable = false);
method public static com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values, String? description = null, boolean nullable = false, String? title = null);
method public static com.google.firebase.ai.type.Schema fromClass(kotlin.reflect.KClass<? extends java.lang.Object?> clazz, boolean nullable = false);
method public java.util.List<com.google.firebase.ai.type.Schema>? getAnyOf();
method public String? getDescription();
method public java.util.List<java.lang.String>? getEnum();
Expand Down Expand Up @@ -1036,6 +1069,7 @@ package com.google.firebase.ai.type {
method public com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values, String? description = null);
method public com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values, String? description = null, boolean nullable = false);
method public com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values, String? description = null, boolean nullable = false, String? title = null);
method public com.google.firebase.ai.type.Schema fromClass(kotlin.reflect.KClass<? extends java.lang.Object?> clazz, boolean nullable = false);
method public com.google.firebase.ai.type.Schema numDouble();
method public com.google.firebase.ai.type.Schema numDouble(String? description = null);
method public com.google.firebase.ai.type.Schema numDouble(String? description = null, boolean nullable = false);
Expand Down
1 change: 1 addition & 0 deletions firebase-ai/firebase-ai.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ dependencies {
implementation(libs.kotlinx.coroutines.android)
implementation(libs.kotlinx.coroutines.reactive)
implementation(libs.reactive.streams)
implementation(libs.kotlin.reflect)
implementation("com.google.guava:listenablefuture:1.0")
implementation("androidx.concurrent:concurrent-futures:1.2.0")
implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.firebase.ai.annotation

import kotlin.reflect.KClass

@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.CLASS, AnnotationTarget.PROPERTY)
public annotation class SchemaDetails(val description: String, val title: String)

@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.PROPERTY)
public annotation class NumSchemaDetails(val minimum: Double, val maximum: Double)

@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.PROPERTY)
public annotation class ListSchemaDetails(
val minItems: Int,
val maxItems: Int,
val clazz: KClass<*>
)

@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.PROPERTY)
public annotation class StringSchemaDetails(val format: String)
141 changes: 141 additions & 0 deletions firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@

package com.google.firebase.ai.type

import com.google.firebase.ai.annotation.ListSchemaDetails
import com.google.firebase.ai.annotation.NumSchemaDetails
import com.google.firebase.ai.annotation.SchemaDetails
import com.google.firebase.ai.annotation.StringSchemaDetails
import kotlin.reflect.KClass
import kotlin.reflect.KProperty1
import kotlin.reflect.full.findAnnotations
import kotlin.reflect.full.memberProperties
import kotlin.reflect.jvm.jvmErasure
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

public abstract class StringFormat private constructor(internal val value: String) {
Expand Down Expand Up @@ -320,6 +330,137 @@ internal constructor(
*/
@JvmStatic
public fun anyOf(schemas: List<Schema>): Schema = Schema(type = "ANYOF", anyOf = schemas)

@JvmStatic
public fun fromClass(clazz: KClass<*>, nullable: Boolean = false): Schema {
return fromClassHelper(clazz, nullable)
}

@JvmStatic
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The @JvmStatic annotation on a private function has no effect and can be misleading. It should be removed for code clarity.

private fun fromClassHelper(
clazz: KClass<*>,
nullable: Boolean = false,
propertyName: String? = null,
schemaDetails: SchemaDetails? = null,
numSchemaDetails: NumSchemaDetails? = null,
listSchemaDetails: ListSchemaDetails? = null,
stringSchemaDetails: StringSchemaDetails? = null
): Schema {
return when (clazz) {
Int::class -> {
integer(
schemaDetails?.description,
nullable,
schemaDetails?.title,
numSchemaDetails?.minimum,
numSchemaDetails?.maximum
)
Comment on lines +350 to +357
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The enclosing braces {} are redundant here and can be removed for conciseness. This also applies to the other primitive type branches (Long, Boolean, Float, Double) in this when statement.

integer(
            schemaDetails?.description,
            nullable,
            schemaDetails?.title,
            numSchemaDetails?.minimum,
            numSchemaDetails?.maximum
          )

}
Long::class -> {
long(
schemaDetails?.description,
nullable,
schemaDetails?.title,
numSchemaDetails?.minimum,
numSchemaDetails?.maximum
)
}
Boolean::class -> {
boolean(schemaDetails?.description, nullable, schemaDetails?.title)
}
Float::class -> {
float(
schemaDetails?.description,
nullable,
schemaDetails?.title,
numSchemaDetails?.minimum,
numSchemaDetails?.maximum
)
}
Double::class -> {
double(
schemaDetails?.description,
nullable,
schemaDetails?.title,
numSchemaDetails?.minimum,
numSchemaDetails?.maximum
)
}
String::class -> {
string(
schemaDetails?.description,
nullable,
stringSchemaDetails?.format?.let { StringFormat.Custom(it) },
schemaDetails?.title
)
}
List::class -> {
if (listSchemaDetails == null) {
throw IllegalStateException(
"${clazz.simpleName}$${propertyName} must include " +
"@ListSchemaDetails to use automatic schema generation."
)
}
array(
fromClassHelper(listSchemaDetails.clazz),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There are a couple of issues with List property handling:

  1. Annotations not applied to items: When generating a schema for a List property, annotations like @NumSchemaDetails or @StringSchemaDetails on that property are ignored. They should be applied to the schema of the list's items. The recursive call to fromClassHelper for the list item type should pass along these annotations.

  2. Item nullability not supported: The nullability of list items (e.g., List<String?>) is not supported. The clazz property in @ListSchemaDetails erases this information. A more robust solution would be to remove the clazz parameter from @ListSchemaDetails and instead use reflection to inspect the property's type arguments to determine the item's type and nullability. While this would be a breaking change to this new API, it's a significant limitation worth addressing.

Here's a suggested fix for passing down the annotations:

fromClassHelper(clazz = listSchemaDetails.clazz, numSchemaDetails = numSchemaDetails, stringSchemaDetails = stringSchemaDetails)

schemaDetails?.description,
nullable,
schemaDetails?.title,
listSchemaDetails.minItems,
listSchemaDetails.maxItems
)
}
else -> {
val isSerializable = clazz.findAnnotations(Serializable::class).isNotEmpty()
if (!isSerializable) {
throw IllegalStateException(
"${clazz.simpleName} must be @Serializable to use automatic " + "schema generation."
)
}
if (!clazz.isData) {
throw IllegalStateException(
"${clazz.simpleName} must be a data class to use automatic " + "schema generation."
)
}
val classSchemaDetails =
schemaDetails
?: clazz.findAnnotations(SchemaDetails::class).firstOrNull()
?: throw IllegalStateException(
"${clazz.simpleName} must include @SchemaDetails to use " +
"automatic schema generation."
)
val properties =
clazz.memberProperties.associate { property: KProperty1<out Any, *> ->
val propertyDetails = property.findAnnotations(SchemaDetails::class).firstOrNull()
val stringDetails = property.findAnnotations(StringSchemaDetails::class).firstOrNull()
val numDetails = property.findAnnotations(NumSchemaDetails::class).firstOrNull()
val listDetails = property.findAnnotations(ListSchemaDetails::class).firstOrNull()
val serialName = property.findAnnotations(SerialName::class).firstOrNull()
val deepPropertyName = serialName?.value ?: property.name
val propertyClass = property.returnType
Pair(
deepPropertyName,
fromClassHelper(
propertyClass.jvmErasure,
propertyClass.isMarkedNullable,
deepPropertyName,
propertyDetails,
numDetails,
listDetails,
stringDetails
)
)
}
obj(
properties,
emptyList(),
classSchemaDetails.description,
nullable,
classSchemaDetails.title
)
}
}
}
}

internal fun toInternal(): Internal {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.google.firebase.ai

import com.google.firebase.ai.annotation.ListSchemaDetails
import com.google.firebase.ai.annotation.NumSchemaDetails
import com.google.firebase.ai.annotation.SchemaDetails
import com.google.firebase.ai.annotation.StringSchemaDetails
import com.google.firebase.ai.type.Schema
import com.google.firebase.ai.type.StringFormat
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import org.junit.Test

class SchemaGenerationTest {

@Test
fun testSchemaGeneration() {
val generatedSchema = Schema.fromClass(TestClass1::class)
val schema =
Schema.obj(
description = "A test class (1)",
title = "TestClass1",
properties =
mapOf(
"val1" to Schema.integer("A test field (1)", false, "var1"),
"val2" to Schema.long("A test field (2)", false, "var2", 20.0, 30.0),
"val3" to Schema.boolean("A test field (3)", false, "var3"),
"val4" to Schema.float("A test field (4)", false, "var4"),
"val5" to Schema.double("A test field (5)", false, "var5"),
"val6" to
Schema.string("A test field (6)", false, StringFormat.Custom("StringFormat"), "var6"),
"val7" to
Schema.array(
Schema.obj(
mapOf("val1" to Schema.integer(nullable = true)),
emptyList(),
"A test class (2)",
false,
"TestClass2",
),
"A test field (7)",
false,
"var7",
0,
500,
),
"val8" to
Schema.obj(
mapOf(
"customSerialName" to Schema.array(Schema.string(), minItems = 0, maxItems = 500)
),
emptyList(),
"A test field (8)",
false,
"var8",
),
),
)
assert(schema.toInternal() == generatedSchema.toInternal())
}

@Serializable
@SchemaDetails("A test class (1)", "TestClass1")
data class TestClass1(
@SchemaDetails("A test field (1)", "var1") val val1: Int,
@NumSchemaDetails(minimum = 20.0, maximum = 30.0)
@SchemaDetails("A test field (2)", "var2")
val val2: Long,
@SchemaDetails("A test field (3)", "var3") val val3: Boolean,
@SchemaDetails("A test field (4)", "var4") val val4: Float,
@SchemaDetails("A test field (5)", "var5") val val5: Double,
@SchemaDetails("A test field (6)", "var6")
@StringSchemaDetails("StringFormat")
val val6: String,
@SchemaDetails("A test field (7)", "var7")
@ListSchemaDetails(0, 500, TestClass2::class)
val val7: List<TestClass2>,
@SchemaDetails("A test field (8)", "var8") val val8: TestClass3,
)

@Serializable
@SchemaDetails("A test class (2)", "TestClass2")
data class TestClass2(val val1: Int?)

@Serializable
@SchemaDetails("A test class (3)", "TestClass3")
data class TestClass3(
@ListSchemaDetails(0, 500, String::class) @SerialName("customSerialName") val val1: List<String>
)
}
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ kotlin-bom = { module = "org.jetbrains.kotlin:kotlin-bom", version.ref = "kotlin
kotlin-coroutines-tasks = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-play-services", version.ref = "coroutines" }
kotlin-stdlib = { module = "org.jetbrains.kotlin:kotlin-stdlib", version.ref = "kotlin" }
kotlin-stdlib-jdk8 = { module = "org.jetbrains.kotlin:kotlin-stdlib-jdk8", version.ref = "kotlin" }
kotlin-reflect = { module = "org.jetbrains.kotlin:kotlin-reflect", version.ref = "kotlin" }
kotlinx-coroutines-android = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-android", version.ref = "coroutines" }
kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "coroutines" }
kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref = "coroutines" }
Expand Down
Loading