diff --git a/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessor.kt b/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessor.kt similarity index 54% rename from firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessor.kt rename to firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessor.kt index 05b3ce72594..94b92830ecc 100644 --- a/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessor.kt +++ b/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessor.kt @@ -16,6 +16,7 @@ package com.google.firebase.ai.ksp +import com.google.devtools.ksp.isPublic import com.google.devtools.ksp.processing.CodeGenerator import com.google.devtools.ksp.processing.Dependencies import com.google.devtools.ksp.processing.KSPLogger @@ -25,6 +26,7 @@ import com.google.devtools.ksp.symbol.ClassKind import com.google.devtools.ksp.symbol.KSAnnotated import com.google.devtools.ksp.symbol.KSAnnotation import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.KSFunctionDeclaration import com.google.devtools.ksp.symbol.KSType import com.google.devtools.ksp.symbol.KSVisitorVoid import com.google.devtools.ksp.symbol.Modifier @@ -39,12 +41,21 @@ import com.squareup.kotlinpoet.TypeSpec import com.squareup.kotlinpoet.ksp.toClassName import com.squareup.kotlinpoet.ksp.toTypeName import com.squareup.kotlinpoet.ksp.writeTo +import java.util.Locale import javax.annotation.processing.Generated -public class SchemaSymbolProcessor( +public class FirebaseSymbolProcessor( private val codeGenerator: CodeGenerator, private val logger: KSPLogger, ) : SymbolProcessor { + // This regex extracts everything in the kdocs until it hits either the end of the kdocs, or + // the first @ like @property or @see, extracting the main body text of the kdoc + private val baseKdocRegex = Regex("""^\s*(.*?)((@\w* .*)|\z)""", RegexOption.DOT_MATCHES_ALL) + // This regex extracts two capture groups from @property tags, the first is the name of the + // property, and the second is the documentation associated with that property + private val propertyKdocRegex = + Regex("""\s*@property (\w*) (.*?)(?=@\w*|\z)""", RegexOption.DOT_MATCHES_ALL) + override fun process(resolver: Resolver): List { resolver .getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Generable") @@ -52,18 +63,174 @@ public class SchemaSymbolProcessor( .map { it to SchemaSymbolProcessorVisitor() } .forEach { (klass, visitor) -> visitor.visitClassDeclaration(klass, Unit) } + resolver + .getSymbolsWithAnnotation("com.google.firebase.ai.annotations.Tool") + .filterIsInstance() + .map { it to FunctionSymbolProcessorVisitor(it, resolver) } + .forEach { (klass, visitor) -> visitor.visitFunctionDeclaration(klass, Unit) } + return emptyList() } + private inner class FunctionSymbolProcessorVisitor( + private val func: KSFunctionDeclaration, + private val resolver: Resolver, + ) : KSVisitorVoid() { + override fun visitFunctionDeclaration(function: KSFunctionDeclaration, data: Unit) { + var shouldError = false + val fullFunctionName = function.qualifiedName?.asString() + if (fullFunctionName == null) { + logger.warn("Error extracting function name.") + shouldError = true + } + if (!function.isPublic()) { + logger.warn("$fullFunctionName must be public.") + shouldError = true + } + val containingClass = function.parentDeclaration as? KSClassDeclaration + val containingClassQualifiedName = containingClass?.qualifiedName?.asString() + if (containingClassQualifiedName == null) { + logger.warn("Could not extract name of containing class of $fullFunctionName.") + shouldError = true + } + if (containingClass == null || !containingClass.isCompanionObject) { + logger.warn( + "$fullFunctionName must be within a companion object $containingClassQualifiedName." + ) + shouldError = true + } + if (function.parameters.size != 1) { + logger.warn("$fullFunctionName must have exactly one parameter") + shouldError = true + } + val parameter = function.parameters.firstOrNull()?.type?.resolve()?.declaration + if (parameter != null) { + if (parameter.annotations.find { it.shortName.getShortName() == "Generable" } == null) { + logger.warn("$fullFunctionName parameter must be annotated @Generable") + shouldError = true + } + if (parameter.annotations.find { it.shortName.getShortName() == "Serializable" } == null) { + logger.warn("$fullFunctionName parameter must be annotated @Serializable") + shouldError = true + } + } + val output = function.returnType?.resolve() + if ( + output != null && + output.toClassName().canonicalName != "kotlinx.serialization.json.JsonObject" + ) { + if ( + output.declaration.annotations.find { it.shortName.getShortName() != "Generable" } == null + ) { + logger.warn("$fullFunctionName output must be annotated @Generable") + shouldError = true + } + if ( + output.declaration.annotations.find { it.shortName.getShortName() != "Serializable" } == + null + ) { + logger.warn("$fullFunctionName output must be annotated @Serializable") + shouldError = true + } + } + if (shouldError) { + logger.error("$fullFunctionName has one or more errors, please resolve them.") + } + val generatedFunctionFile = generateFileSpec(function) + val containingFile = function.containingFile + if (containingFile == null) { + logger.error("$fullFunctionName must be in a file in the build.") + throw RuntimeException() + } + generatedFunctionFile.writeTo( + codeGenerator, + Dependencies(true, containingFile), + ) + } + + private fun generateFileSpec(functionDeclaration: KSFunctionDeclaration): FileSpec { + val generatedClassName = + functionDeclaration.simpleName.asString().replaceFirstChar { + if (it.isLowerCase()) it.titlecase(Locale.ROOT) else it.toString() + } + "GeneratedFunctionDeclaration" + return FileSpec.builder(functionDeclaration.packageName.asString(), generatedClassName) + .addImport("com.google.firebase.ai.type", "AutoFunctionDeclaration") + .addType( + TypeSpec.classBuilder(generatedClassName) + .addAnnotation(Generated::class) + .addType( + TypeSpec.companionObjectBuilder() + .addProperty( + PropertySpec.builder( + "FUNCTION_DECLARATION", + ClassName("com.google.firebase.ai.type", "AutoFunctionDeclaration") + .parameterizedBy( + functionDeclaration.parameters.first().type.resolve().toClassName(), + functionDeclaration.returnType?.resolve()?.toClassName() + ?: ClassName("kotlinx.serialization.json", "JsonObject") + ), + KModifier.PUBLIC, + ) + .mutable(false) + .initializer( + CodeBlock.builder() + .add(generateCodeBlockForFunctionDeclaration(functionDeclaration)) + .build() + ) + .build() + ) + .build() + ) + .build() + ) + .build() + } + + fun generateCodeBlockForFunctionDeclaration( + functionDeclaration: KSFunctionDeclaration + ): CodeBlock { + val builder = CodeBlock.builder() + val returnType = functionDeclaration.returnType + val hasTypedOutput = + !(returnType == null || + returnType.resolve().toClassName().canonicalName == + "kotlinx.serialization.json.JsonObject") + val kdocDescription = functionDeclaration.docString?.let { extractBaseKdoc(it) } + val annotationDescription = + getStringFromAnnotation( + functionDeclaration.annotations.find { it.shortName.getShortName() == "Tool" }, + "description" + ) + val description = annotationDescription ?: kdocDescription ?: "" + val inputSchemaName = + "${ + functionDeclaration.parameters.first().type.resolve().toClassName().canonicalName + }GeneratedSchema.SCHEMA" + builder + .addStatement("AutoFunctionDeclaration.create(") + .indent() + .addStatement("functionName = %S,", functionDeclaration.simpleName.getShortName()) + .addStatement("description = %S,", description) + .addStatement("inputSchema = $inputSchemaName,") + if (hasTypedOutput) { + val outputSchemaName = + "${ + functionDeclaration.returnType!!.resolve().toClassName().canonicalName + }GeneratedSchema.SCHEMA" + builder.addStatement("outputSchema = $outputSchemaName,") + } + builder.addStatement( + "functionReference = " + + functionDeclaration.qualifiedName!!.getQualifier() + + "::${functionDeclaration.qualifiedName!!.getShortName()}," + ) + builder.unindent().addStatement(")") + return builder.build() + } + } + private inner class SchemaSymbolProcessorVisitor() : KSVisitorVoid() { private val numberTypes = setOf("kotlin.Int", "kotlin.Long", "kotlin.Double", "kotlin.Float") - // This regex extracts everything in the kdocs until it hits either the end of the kdocs, or - // the first @ like @property or @see, extracting the main body text of the kdoc - private val baseKdocRegex = Regex("""^\s*(.*?)((@\w* .*)|\z)""", RegexOption.DOT_MATCHES_ALL) - // This regex extracts two capture groups from @property tags, the first is the name of the - // property, and the second is the documentation associated with that property - private val propertyKdocRegex = - Regex("""\s*@property (\w*) (.*?)(?=@\w*|\z)""", RegexOption.DOT_MATCHES_ALL) override fun visitClassDeclaration(classDeclaration: KSClassDeclaration, data: Unit) { val isDataClass = classDeclaration.modifiers.contains(Modifier.DATA) @@ -76,13 +243,10 @@ public class SchemaSymbolProcessor( throw RuntimeException() } val generatedSchemaFile = generateFileSpec(classDeclaration) - - classDeclaration.containingFile?.let { - generatedSchemaFile.writeTo( - codeGenerator, - Dependencies(true, containingFile), - ) - } + generatedSchemaFile.writeTo( + codeGenerator, + Dependencies(true, containingFile), + ) } fun generateFileSpec(classDeclaration: KSClassDeclaration): FileSpec { @@ -293,73 +457,72 @@ public class SchemaSymbolProcessor( builder.addStatement("nullable = %L)", className.isNullable).unindent() return builder.build() } + } - private fun getDescriptionFromAnnotations( - guideAnnotation: KSAnnotation?, - generableClassAnnotation: KSAnnotation?, - description: String?, - baseKdoc: String?, - ): String? { - val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description") - - val guideClassDescription = getStringFromAnnotation(generableClassAnnotation, "description") + private fun getDescriptionFromAnnotations( + guideAnnotation: KSAnnotation?, + generableClassAnnotation: KSAnnotation?, + description: String?, + baseKdoc: String?, + ): String? { + val guidePropertyDescription = getStringFromAnnotation(guideAnnotation, "description") - return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc - } + val guideClassDescription = getStringFromAnnotation(generableClassAnnotation, "description") - private fun getDoubleFromAnnotation( - guideAnnotation: KSAnnotation?, - doubleName: String, - ): Double? { - val guidePropertyDoubleValue = - guideAnnotation - ?.arguments - ?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true } - ?.value as? Double - if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) { - return null - } - return guidePropertyDoubleValue + return guidePropertyDescription ?: guideClassDescription ?: description ?: baseKdoc + } + private fun getDoubleFromAnnotation( + guideAnnotation: KSAnnotation?, + doubleName: String, + ): Double? { + val guidePropertyDoubleValue = + guideAnnotation + ?.arguments + ?.firstOrNull { it.name?.getShortName()?.equals(doubleName) == true } + ?.value as? Double + if (guidePropertyDoubleValue == null || guidePropertyDoubleValue == -1.0) { + return null } + return guidePropertyDoubleValue + } - private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? { - val guidePropertyIntValue = - guideAnnotation - ?.arguments - ?.firstOrNull { it.name?.getShortName()?.equals(intName) == true } - ?.value as? Int - if (guidePropertyIntValue == null || guidePropertyIntValue == -1) { - return null - } - return guidePropertyIntValue + private fun getIntFromAnnotation(guideAnnotation: KSAnnotation?, intName: String): Int? { + val guidePropertyIntValue = + guideAnnotation + ?.arguments + ?.firstOrNull { it.name?.getShortName()?.equals(intName) == true } + ?.value as? Int + if (guidePropertyIntValue == null || guidePropertyIntValue == -1) { + return null } + return guidePropertyIntValue + } - private fun getStringFromAnnotation( - guideAnnotation: KSAnnotation?, - stringName: String, - ): String? { - val guidePropertyStringValue = - guideAnnotation - ?.arguments - ?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true } - ?.value as? String - if (guidePropertyStringValue.isNullOrEmpty()) { - return null - } - return guidePropertyStringValue + private fun getStringFromAnnotation( + guideAnnotation: KSAnnotation?, + stringName: String, + ): String? { + val guidePropertyStringValue = + guideAnnotation + ?.arguments + ?.firstOrNull { it.name?.getShortName()?.equals(stringName) == true } + ?.value as? String + if (guidePropertyStringValue.isNullOrEmpty()) { + return null } + return guidePropertyStringValue + } - private fun extractBaseKdoc(kdoc: String): String? { - return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let { - if (it.isNullOrEmpty()) null else it - } + private fun extractBaseKdoc(kdoc: String): String? { + return baseKdocRegex.matchEntire(kdoc)?.groups?.get(1)?.value?.trim().let { + if (it.isNullOrEmpty()) null else it } + } - private fun extractPropertyKdocs(kdoc: String): Map { - return propertyKdocRegex - .findAll(kdoc) - .map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() } - .toMap() - } + private fun extractPropertyKdocs(kdoc: String): Map { + return propertyKdocRegex + .findAll(kdoc) + .map { it.groups[1]!!.value to it.groups[2]!!.value.replace("\n", "").trim() } + .toMap() } } diff --git a/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessorProvider.kt b/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessorProvider.kt similarity index 85% rename from firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessorProvider.kt rename to firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessorProvider.kt index 2c8015bc8a9..771706d9866 100644 --- a/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/SchemaSymbolProcessorProvider.kt +++ b/firebase-ai-ksp-processor/src/main/kotlin/com/google/firebase/ai/ksp/FirebaseSymbolProcessorProvider.kt @@ -20,8 +20,8 @@ import com.google.devtools.ksp.processing.SymbolProcessor import com.google.devtools.ksp.processing.SymbolProcessorEnvironment import com.google.devtools.ksp.processing.SymbolProcessorProvider -public class SchemaSymbolProcessorProvider : SymbolProcessorProvider { +public class FirebaseSymbolProcessorProvider : SymbolProcessorProvider { override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor { - return SchemaSymbolProcessor(environment.codeGenerator, environment.logger) + return FirebaseSymbolProcessor(environment.codeGenerator, environment.logger) } } diff --git a/firebase-ai-ksp-processor/src/main/resources/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider b/firebase-ai-ksp-processor/src/main/resources/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider index 83d92f28c7e..b5a8cffc5a6 100644 --- a/firebase-ai-ksp-processor/src/main/resources/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider +++ b/firebase-ai-ksp-processor/src/main/resources/META-INF/services/com.google.devtools.ksp.processing.SymbolProcessorProvider @@ -1 +1 @@ -com.google.firebase.ai.ksp.SchemaSymbolProcessorProvider \ No newline at end of file +com.google.firebase.ai.ksp.FirebaseSymbolProcessorProvider \ No newline at end of file diff --git a/firebase-ai/api.txt b/firebase-ai/api.txt index 01ac626e1cb..a640363da77 100644 --- a/firebase-ai/api.txt +++ b/firebase-ai/api.txt @@ -124,6 +124,11 @@ package com.google.firebase.ai.annotations { property public abstract String pattern; } + @kotlin.annotation.Retention(kotlin.annotation.AnnotationRetention.SOURCE) @kotlin.annotation.Target(allowedTargets=kotlin.annotation.AnnotationTarget.FUNCTION) public @interface Tool { + method public abstract String description() default ""; + property public abstract String description; + } + } package com.google.firebase.ai.java { @@ -438,6 +443,10 @@ package com.google.firebase.ai.type { public abstract class FirebaseAIException extends java.lang.RuntimeException { } + public final class FirebaseAutoFunctionException extends com.google.firebase.ai.type.FirebaseAIException { + ctor public FirebaseAutoFunctionException(String message); + } + public final class FunctionCallPart implements com.google.firebase.ai.type.Part { ctor public FunctionCallPart(String name, java.util.Map args); ctor public FunctionCallPart(String name, java.util.Map args, String? id = null); @@ -1260,6 +1269,7 @@ package com.google.firebase.ai.type { public final class RequestOptions { ctor public RequestOptions(); ctor public RequestOptions(long timeoutInMillis = 180.seconds.inWholeMilliseconds); + ctor public RequestOptions(long timeoutInMillis = 180.seconds.inWholeMilliseconds, int autoFunctionCallingTurnLimit = 10); } public final class RequestTimeoutException extends com.google.firebase.ai.type.FirebaseAIException { diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt index 64356508d58..60255ccac00 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt @@ -22,6 +22,7 @@ import com.google.firebase.ai.type.FunctionCallPart import com.google.firebase.ai.type.FunctionResponsePart import com.google.firebase.ai.type.GenerateContentResponse import com.google.firebase.ai.type.InvalidStateException +import com.google.firebase.ai.type.RequestTimeoutException import com.google.firebase.ai.type.TextPart import com.google.firebase.ai.type.content import java.util.LinkedList @@ -50,6 +51,7 @@ public class Chat( public val history: MutableList = ArrayList() ) { private var lock = Semaphore(1) + private var turns: Int = 0 /** * Sends a message using the provided [prompt]; automatically providing the existing [history] as @@ -67,24 +69,33 @@ public class Chat( prompt.assertComesFromUser() attemptLock() var response: GenerateContentResponse - var tempPrompt = prompt try { + val tempHistory = mutableListOf(prompt) while (true) { - response = model.generateContent(listOf(*history.toTypedArray(), tempPrompt)) - val responsePart = response.candidates.first().content.parts.first() + response = + model.generateContent(listOf(*history.toTypedArray(), *tempHistory.toTypedArray())) + tempHistory.add(response.candidates.first().content) + val functionCallParts = + response.candidates.first().content.parts.filterIsInstance() - history.add(tempPrompt) - history.add(response.candidates.first().content) - if (responsePart is FunctionCallPart && model.hasFunction(responsePart)) { - val output = model.executeFunction(responsePart) - tempPrompt = Content("function", listOf(FunctionResponsePart(responsePart.name, output))) + if (functionCallParts.isNotEmpty()) { + if (model.getTurnLimit() < ++turns) { + throw RequestTimeoutException("Request took too many turns", history = tempHistory) + } + if (functionCallParts.all { model.hasFunction(it) }) { + val functionResponsePart = + functionCallParts.map { FunctionResponsePart(it.name, model.executeFunction(it)) } + tempHistory.add(Content("function", functionResponsePart)) + } } else { break } } + history.addAll(tempHistory) return response } finally { lock.release() + turns = 0 } } @@ -153,6 +164,7 @@ public class Chat( return flow .transform { response -> automaticFunctionExecutingTransform(this, tempHistory, response) } .onCompletion { + turns = 0 lock.release() if (it == null) { history.addAll(tempHistory) @@ -210,22 +222,7 @@ public class Chat( addTextToHistory(tempHistory, part) } is FunctionCallPart -> { - val functionCall = - response.candidates.first().content.parts.first { it is FunctionCallPart } - as FunctionCallPart - if (model.hasFunction(functionCall)) { - val output = model.executeFunction(functionCall) - val functionResponse = - Content("function", listOf(FunctionResponsePart(functionCall.name, output))) - tempHistory.add(response.candidates.first().content) - tempHistory.add(functionResponse) - model - .generateContentStream(listOf(*history.toTypedArray(), *tempHistory.toTypedArray())) - .collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) } - } else { - transformer.emit(response) - tempHistory.add(Content("model", listOf(part))) - } + // do nothing } else -> { transformer.emit(response) @@ -233,6 +230,28 @@ public class Chat( } } } + val functionCallParts = + response.candidates.first().content.parts.filterIsInstance() + if (functionCallParts.isNotEmpty()) { + if (functionCallParts.all { model.hasFunction(it) }) { + if (model.getTurnLimit() < ++turns) { + throw RequestTimeoutException("Request took too many turns", history = tempHistory) + } + val functionResponses = + Content( + "function", + functionCallParts.map { FunctionResponsePart(it.name, model.executeFunction(it)) } + ) + tempHistory.add(Content("model", functionCallParts)) + tempHistory.add(functionResponses) + model + .generateContentStream(listOf(*history.toTypedArray(), *tempHistory.toTypedArray())) + .collect { automaticFunctionExecutingTransform(transformer, tempHistory, it) } + } else { + transformer.emit(response) + tempHistory.add(Content("model", functionCallParts)) + } + } } private fun addTextToHistory(tempHistory: MutableList, textPart: TextPart) { diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/GenerativeModel.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/GenerativeModel.kt index 1eef17e7aa5..6b92d16fc33 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/GenerativeModel.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/GenerativeModel.kt @@ -27,6 +27,7 @@ import com.google.firebase.ai.type.Content import com.google.firebase.ai.type.CountTokensResponse import com.google.firebase.ai.type.FinishReason import com.google.firebase.ai.type.FirebaseAIException +import com.google.firebase.ai.type.FirebaseAutoFunctionException import com.google.firebase.ai.type.FunctionCallPart import com.google.firebase.ai.type.GenerateContentResponse import com.google.firebase.ai.type.GenerateObjectResponse @@ -52,6 +53,7 @@ import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.jsonObject import kotlinx.serialization.serializerOrNull @@ -68,7 +70,7 @@ internal constructor( private val toolConfig: ToolConfig? = null, private val systemInstruction: Content? = null, private val generativeBackend: GenerativeBackend = GenerativeBackend.googleAI(), - private val controller: APIController, + internal val controller: APIController, ) { internal constructor( modelName: String, @@ -324,7 +326,7 @@ internal constructor( internal fun hasFunction(call: FunctionCallPart): Boolean { return tools ?.flatMap { it.autoFunctionDeclarations?.filterNotNull() ?: emptyList() } - ?.firstOrNull { it.name == call.name } != null + ?.firstOrNull { it.name == call.name && it.functionReference != null } != null } @OptIn(InternalSerializationApi::class) @@ -356,14 +358,20 @@ internal constructor( val functionReference = functionDeclaration.functionReference ?: throw RuntimeException("Function reference for ${functionDeclaration.name} is missing") - val output = functionReference.invoke(input) - val outputSerializer = functionDeclaration.outputSchema?.clazz?.serializerOrNull() - if (outputSerializer != null) { - return Json.encodeToJsonElement(outputSerializer, output).jsonObject + try { + val output = functionReference.invoke(input) + val outputSerializer = functionDeclaration.outputSchema?.clazz?.serializerOrNull() + if (outputSerializer != null) { + return Json.encodeToJsonElement(outputSerializer, output).jsonObject + } + return output as JsonObject + } catch (e: FirebaseAutoFunctionException) { + return JsonObject(mapOf("error" to JsonPrimitive(e.message))) } - return output as JsonObject } + internal fun getTurnLimit(): Int = controller.getTurnLimit() + @OptIn(ExperimentalSerializationApi::class) private fun constructRequest(overrideConfig: GenerationConfig? = null, vararg prompt: Content) = GenerateContentRequest( diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Generable.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Generable.kt index 64b18ca071b..afc5256eef1 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Generable.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Generable.kt @@ -16,8 +16,10 @@ package com.google.firebase.ai.annotations +import com.google.firebase.ai.type.JsonSchema + /** - * This annotation is used with the firebase-ai-ksp-processor plugin to generate JsonSchema that + * This annotation is used with the firebase-ai-ksp-processor plugin to generate [JsonSchema] that * match an existing kotlin class structure. For more info see: * https://github.com/firebase/firebase-android-sdk/blob/main/firebase-ai-ksp-processor/README.md * diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Tool.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Tool.kt new file mode 100644 index 00000000000..a0a81d2ff57 --- /dev/null +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/annotations/Tool.kt @@ -0,0 +1,30 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai.annotations + +import com.google.firebase.ai.type.AutoFunctionDeclaration + +/** + * This annotation is used with the firebase-ai-ksp-processor plugin to generate + * [AutoFunctionDeclaration]s that match an existing kotlin function. For more info see: + * https://github.com/firebase/firebase-android-sdk/blob/main/firebase-ai-ksp-processor/README.md + * + * @property description a description of the function + */ +@Target(AnnotationTarget.FUNCTION) +@Retention(AnnotationRetention.SOURCE) +public annotation class Tool(public val description: String = "") diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt index e992f92e674..ac2471f4b8d 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt @@ -227,6 +227,8 @@ internal constructor( throw FirebaseAIException.from(e) } + fun getTurnLimit(): Int = requestOptions.autoFunctionCallingTurnLimit + private fun getBidiEndpoint(location: String): String = when (backend?.backend) { GenerativeBackendEnum.VERTEX_AI, diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Exceptions.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Exceptions.kt index fed7660d08e..5af6d9416dc 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Exceptions.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Exceptions.kt @@ -156,8 +156,11 @@ internal constructor(public val response: GenerateContentResponse, cause: Throwa * Usually occurs due to a user specified [timeout][RequestOptions.timeout]. */ public class RequestTimeoutException -internal constructor(message: String, cause: Throwable? = null) : - FirebaseAIException(message, cause) +internal constructor( + message: String, + cause: Throwable? = null, + history: List = emptyList() +) : FirebaseAIException(message, cause) /** * The specified Vertex AI location is invalid. @@ -204,6 +207,9 @@ public class ServiceConnectionHandshakeFailedException(message: String, cause: T public class PermissionMissingException(message: String, cause: Throwable? = null) : FirebaseAIException(message, cause) +/** Thrown when a function invoked by the model has an error that should be returned to the model */ +public class FirebaseAutoFunctionException(message: String) : FirebaseAIException(message) + /** Catch all case for exceptions not explicitly expected. */ public class UnknownException internal constructor(message: String, cause: Throwable? = null) : FirebaseAIException(message, cause) diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/RequestOptions.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/RequestOptions.kt index dc4211e7222..a1e3f6020ca 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/RequestOptions.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/RequestOptions.kt @@ -27,6 +27,7 @@ internal constructor( internal val timeout: Duration, internal val endpoint: String = "https://firebasevertexai.googleapis.com", internal val apiVersion: String = "v1beta", + internal val autoFunctionCallingTurnLimit: Int ) { /** @@ -38,7 +39,9 @@ internal constructor( @JvmOverloads public constructor( timeoutInMillis: Long = 180.seconds.inWholeMilliseconds, + autoFunctionCallingTurnLimit: Int = 10 ) : this( timeout = timeoutInMillis.toDuration(DurationUnit.MILLISECONDS), + autoFunctionCallingTurnLimit = autoFunctionCallingTurnLimit ) } diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/GenerativeModelTesting.kt b/firebase-ai/src/test/java/com/google/firebase/ai/GenerativeModelTesting.kt index b84f89fd223..45c7aab7a80 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/GenerativeModelTesting.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/GenerativeModelTesting.kt @@ -82,7 +82,11 @@ internal class GenerativeModelTesting { APIController( "super_cool_test_key", "gemini-2.5-flash", - RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"), + RequestOptions( + timeout = 5.seconds, + endpoint = "https://my.custom.endpoint", + autoFunctionCallingTurnLimit = 10 + ), mockEngine, TEST_CLIENT_ID, mockFirebaseApp, diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/common/APIControllerTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/common/APIControllerTests.kt index f3e818085f7..bf2ef116d12 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/common/APIControllerTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/common/APIControllerTests.kt @@ -86,7 +86,7 @@ internal class APIControllerTests { @Test fun `(generateContent) respects a custom timeout`() = - commonTest(requestOptions = RequestOptions(2.seconds)) { + commonTest(requestOptions = RequestOptions(2.seconds.inWholeMilliseconds, 10)) { shouldThrow { withTimeout(testTimeout) { apiController.generateContent(textGenerateContentRequest("test")) @@ -146,7 +146,11 @@ internal class RequestFormatTests { APIController( "super_cool_test_key", "gemini-pro-2.5", - RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"), + RequestOptions( + timeout = 5.seconds, + endpoint = "https://my.custom.endpoint", + autoFunctionCallingTurnLimit = 10 + ), mockEngine, TEST_CLIENT_ID, mockFirebaseApp,