diff --git a/firebase-ai/api.txt b/firebase-ai/api.txt index b979c361b52..b6f867d91b9 100644 --- a/firebase-ai/api.txt +++ b/firebase-ai/api.txt @@ -80,6 +80,38 @@ package com.google.firebase.ai { } +package com.google.firebase.ai.annotation { + + @kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.RUNTIME) @kotlin.annotation.Target(allowedTargets=kotlin.annotation.AnnotationTarget.PROPERTY) public @interface ListSchemaDetails { + method public abstract kotlin.reflect.KClass clazz(); + method public abstract int maxItems(); + method public abstract int minItems(); + property public abstract kotlin.reflect.KClass clazz; + property public abstract int maxItems; + property public abstract int minItems; + } + + @kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.RUNTIME) @kotlin.annotation.Target(allowedTargets=kotlin.annotation.AnnotationTarget.PROPERTY) public @interface NumSchemaDetails { + method public abstract double maximum(); + method public abstract double minimum(); + property public abstract double maximum; + property public abstract double minimum; + } + + @kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.RUNTIME) @kotlin.annotation.Target(allowedTargets={kotlin.annotation.AnnotationTarget.CLASS, kotlin.annotation.AnnotationTarget.PROPERTY}) public @interface SchemaDetails { + method public abstract String description(); + method public abstract String title(); + property public abstract String description; + property public abstract String title; + } + + @kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.RUNTIME) @kotlin.annotation.Target(allowedTargets=kotlin.annotation.AnnotationTarget.PROPERTY) public @interface StringSchemaDetails { + method public abstract String format(); + property public abstract String format; + } + +} + package com.google.firebase.ai.java { public abstract class ChatFutures { @@ -955,6 +987,7 @@ package com.google.firebase.ai.type { method public static com.google.firebase.ai.type.Schema enumeration(java.util.List values, String? description = null); method public static com.google.firebase.ai.type.Schema enumeration(java.util.List values, String? description = null, boolean nullable = false); method public static com.google.firebase.ai.type.Schema enumeration(java.util.List values, String? description = null, boolean nullable = false, String? title = null); + method public static com.google.firebase.ai.type.Schema fromClass(kotlin.reflect.KClass clazz, boolean nullable = false); method public java.util.List? getAnyOf(); method public String? getDescription(); method public java.util.List? getEnum(); @@ -1036,6 +1069,7 @@ package com.google.firebase.ai.type { method public com.google.firebase.ai.type.Schema enumeration(java.util.List values, String? description = null); method public com.google.firebase.ai.type.Schema enumeration(java.util.List values, String? description = null, boolean nullable = false); method public com.google.firebase.ai.type.Schema enumeration(java.util.List values, String? description = null, boolean nullable = false, String? title = null); + method public com.google.firebase.ai.type.Schema fromClass(kotlin.reflect.KClass clazz, boolean nullable = false); method public com.google.firebase.ai.type.Schema numDouble(); method public com.google.firebase.ai.type.Schema numDouble(String? description = null); method public com.google.firebase.ai.type.Schema numDouble(String? description = null, boolean nullable = false); diff --git a/firebase-ai/firebase-ai.gradle.kts b/firebase-ai/firebase-ai.gradle.kts index ba7f21b56fb..be0022b7184 100644 --- a/firebase-ai/firebase-ai.gradle.kts +++ b/firebase-ai/firebase-ai.gradle.kts @@ -105,6 +105,7 @@ dependencies { implementation(libs.kotlinx.coroutines.android) implementation(libs.kotlinx.coroutines.reactive) implementation(libs.reactive.streams) + implementation(libs.kotlin.reflect) implementation("com.google.guava:listenablefuture:1.0") implementation("androidx.concurrent:concurrent-futures:1.2.0") implementation("androidx.concurrent:concurrent-futures-ktx:1.2.0") diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotation/SchemaDetails.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotation/SchemaDetails.kt new file mode 100644 index 00000000000..88dfbfbaa0b --- /dev/null +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotation/SchemaDetails.kt @@ -0,0 +1,38 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.firebase.ai.annotation + +import kotlin.reflect.KClass + +@Retention(AnnotationRetention.RUNTIME) +@Target(AnnotationTarget.CLASS, AnnotationTarget.PROPERTY) +public annotation class SchemaDetails(val description: String, val title: String) + +@Retention(AnnotationRetention.RUNTIME) +@Target(AnnotationTarget.PROPERTY) +public annotation class NumSchemaDetails(val minimum: Double, val maximum: Double) + +@Retention(AnnotationRetention.RUNTIME) +@Target(AnnotationTarget.PROPERTY) +public annotation class ListSchemaDetails( + val minItems: Int, + val maxItems: Int, + val clazz: KClass<*> +) + +@Retention(AnnotationRetention.RUNTIME) +@Target(AnnotationTarget.PROPERTY) +public annotation class StringSchemaDetails(val format: String) diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt index 5f2f6ca9350..efb2ae38045 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt @@ -16,6 +16,16 @@ package com.google.firebase.ai.type +import com.google.firebase.ai.annotation.ListSchemaDetails +import com.google.firebase.ai.annotation.NumSchemaDetails +import com.google.firebase.ai.annotation.SchemaDetails +import com.google.firebase.ai.annotation.StringSchemaDetails +import kotlin.reflect.KClass +import kotlin.reflect.KProperty1 +import kotlin.reflect.full.findAnnotations +import kotlin.reflect.full.memberProperties +import kotlin.reflect.jvm.jvmErasure +import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable public abstract class StringFormat private constructor(internal val value: String) { @@ -320,6 +330,137 @@ internal constructor( */ @JvmStatic public fun anyOf(schemas: List): Schema = Schema(type = "ANYOF", anyOf = schemas) + + @JvmStatic + public fun fromClass(clazz: KClass<*>, nullable: Boolean = false): Schema { + return fromClassHelper(clazz, nullable) + } + + @JvmStatic + private fun fromClassHelper( + clazz: KClass<*>, + nullable: Boolean = false, + propertyName: String? = null, + schemaDetails: SchemaDetails? = null, + numSchemaDetails: NumSchemaDetails? = null, + listSchemaDetails: ListSchemaDetails? = null, + stringSchemaDetails: StringSchemaDetails? = null + ): Schema { + return when (clazz) { + Int::class -> { + integer( + schemaDetails?.description, + nullable, + schemaDetails?.title, + numSchemaDetails?.minimum, + numSchemaDetails?.maximum + ) + } + Long::class -> { + long( + schemaDetails?.description, + nullable, + schemaDetails?.title, + numSchemaDetails?.minimum, + numSchemaDetails?.maximum + ) + } + Boolean::class -> { + boolean(schemaDetails?.description, nullable, schemaDetails?.title) + } + Float::class -> { + float( + schemaDetails?.description, + nullable, + schemaDetails?.title, + numSchemaDetails?.minimum, + numSchemaDetails?.maximum + ) + } + Double::class -> { + double( + schemaDetails?.description, + nullable, + schemaDetails?.title, + numSchemaDetails?.minimum, + numSchemaDetails?.maximum + ) + } + String::class -> { + string( + schemaDetails?.description, + nullable, + stringSchemaDetails?.format?.let { StringFormat.Custom(it) }, + schemaDetails?.title + ) + } + List::class -> { + if (listSchemaDetails == null) { + throw IllegalStateException( + "${clazz.simpleName}$${propertyName} must include " + + "@ListSchemaDetails to use automatic schema generation." + ) + } + array( + fromClassHelper(listSchemaDetails.clazz), + schemaDetails?.description, + nullable, + schemaDetails?.title, + listSchemaDetails.minItems, + listSchemaDetails.maxItems + ) + } + else -> { + val isSerializable = clazz.findAnnotations(Serializable::class).isNotEmpty() + if (!isSerializable) { + throw IllegalStateException( + "${clazz.simpleName} must be @Serializable to use automatic " + "schema generation." + ) + } + if (!clazz.isData) { + throw IllegalStateException( + "${clazz.simpleName} must be a data class to use automatic " + "schema generation." + ) + } + val classSchemaDetails = + schemaDetails + ?: clazz.findAnnotations(SchemaDetails::class).firstOrNull() + ?: throw IllegalStateException( + "${clazz.simpleName} must include @SchemaDetails to use " + + "automatic schema generation." + ) + val properties = + clazz.memberProperties.associate { property: KProperty1 -> + val propertyDetails = property.findAnnotations(SchemaDetails::class).firstOrNull() + val stringDetails = property.findAnnotations(StringSchemaDetails::class).firstOrNull() + val numDetails = property.findAnnotations(NumSchemaDetails::class).firstOrNull() + val listDetails = property.findAnnotations(ListSchemaDetails::class).firstOrNull() + val serialName = property.findAnnotations(SerialName::class).firstOrNull() + val deepPropertyName = serialName?.value ?: property.name + val propertyClass = property.returnType + Pair( + deepPropertyName, + fromClassHelper( + propertyClass.jvmErasure, + propertyClass.isMarkedNullable, + deepPropertyName, + propertyDetails, + numDetails, + listDetails, + stringDetails + ) + ) + } + obj( + properties, + emptyList(), + classSchemaDetails.description, + nullable, + classSchemaDetails.title + ) + } + } + } } internal fun toInternal(): Internal { diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/SchemaGenerationTest.kt b/firebase-ai/src/test/java/com/google/firebase/ai/SchemaGenerationTest.kt new file mode 100644 index 00000000000..40d4202149e --- /dev/null +++ b/firebase-ai/src/test/java/com/google/firebase/ai/SchemaGenerationTest.kt @@ -0,0 +1,104 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.firebase.ai + +import com.google.firebase.ai.annotation.ListSchemaDetails +import com.google.firebase.ai.annotation.NumSchemaDetails +import com.google.firebase.ai.annotation.SchemaDetails +import com.google.firebase.ai.annotation.StringSchemaDetails +import com.google.firebase.ai.type.Schema +import com.google.firebase.ai.type.StringFormat +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.junit.Test + +class SchemaGenerationTest { + + @Test + fun testSchemaGeneration() { + val generatedSchema = Schema.fromClass(TestClass1::class) + val schema = + Schema.obj( + description = "A test class (1)", + title = "TestClass1", + properties = + mapOf( + "val1" to Schema.integer("A test field (1)", false, "var1"), + "val2" to Schema.long("A test field (2)", false, "var2", 20.0, 30.0), + "val3" to Schema.boolean("A test field (3)", false, "var3"), + "val4" to Schema.float("A test field (4)", false, "var4"), + "val5" to Schema.double("A test field (5)", false, "var5"), + "val6" to + Schema.string("A test field (6)", false, StringFormat.Custom("StringFormat"), "var6"), + "val7" to + Schema.array( + Schema.obj( + mapOf("val1" to Schema.integer(nullable = true)), + emptyList(), + "A test class (2)", + false, + "TestClass2", + ), + "A test field (7)", + false, + "var7", + 0, + 500, + ), + "val8" to + Schema.obj( + mapOf( + "customSerialName" to Schema.array(Schema.string(), minItems = 0, maxItems = 500) + ), + emptyList(), + "A test field (8)", + false, + "var8", + ), + ), + ) + assert(schema.toInternal() == generatedSchema.toInternal()) + } + + @Serializable + @SchemaDetails("A test class (1)", "TestClass1") + data class TestClass1( + @SchemaDetails("A test field (1)", "var1") val val1: Int, + @NumSchemaDetails(minimum = 20.0, maximum = 30.0) + @SchemaDetails("A test field (2)", "var2") + val val2: Long, + @SchemaDetails("A test field (3)", "var3") val val3: Boolean, + @SchemaDetails("A test field (4)", "var4") val val4: Float, + @SchemaDetails("A test field (5)", "var5") val val5: Double, + @SchemaDetails("A test field (6)", "var6") + @StringSchemaDetails("StringFormat") + val val6: String, + @SchemaDetails("A test field (7)", "var7") + @ListSchemaDetails(0, 500, TestClass2::class) + val val7: List, + @SchemaDetails("A test field (8)", "var8") val val8: TestClass3, + ) + + @Serializable + @SchemaDetails("A test class (2)", "TestClass2") + data class TestClass2(val val1: Int?) + + @Serializable + @SchemaDetails("A test class (3)", "TestClass3") + data class TestClass3( + @ListSchemaDetails(0, 500, String::class) @SerialName("customSerialName") val val1: List + ) +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 7e3fd691d6a..ce52e964167 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -145,6 +145,7 @@ kotlin-bom = { module = "org.jetbrains.kotlin:kotlin-bom", version.ref = "kotlin kotlin-coroutines-tasks = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-play-services", version.ref = "coroutines" } kotlin-stdlib = { module = "org.jetbrains.kotlin:kotlin-stdlib", version.ref = "kotlin" } kotlin-stdlib-jdk8 = { module = "org.jetbrains.kotlin:kotlin-stdlib-jdk8", version.ref = "kotlin" } +kotlin-reflect = { module = "org.jetbrains.kotlin:kotlin-reflect", version.ref = "kotlin" } kotlinx-coroutines-android = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-android", version.ref = "coroutines" } kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "coroutines" } kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref = "coroutines" }