diff --git a/firebase-ai/api.txt b/firebase-ai/api.txt index f73c51d7112..b5932bf9b0e 100644 --- a/firebase-ai/api.txt +++ b/firebase-ai/api.txt @@ -36,6 +36,10 @@ package com.google.firebase.ai { method @com.google.firebase.ai.type.PublicPreviewAPI public com.google.firebase.ai.LiveGenerativeModel liveModel(String modelName, com.google.firebase.ai.type.LiveGenerationConfig? generationConfig = null, java.util.List? tools = null); method @com.google.firebase.ai.type.PublicPreviewAPI public com.google.firebase.ai.LiveGenerativeModel liveModel(String modelName, com.google.firebase.ai.type.LiveGenerationConfig? generationConfig = null, java.util.List? tools = null, com.google.firebase.ai.type.Content? systemInstruction = null); method @com.google.firebase.ai.type.PublicPreviewAPI public com.google.firebase.ai.LiveGenerativeModel liveModel(String modelName, com.google.firebase.ai.type.LiveGenerationConfig? generationConfig = null, java.util.List? tools = null, com.google.firebase.ai.type.Content? systemInstruction = null, com.google.firebase.ai.type.RequestOptions requestOptions = com.google.firebase.ai.type.RequestOptions()); + method @com.google.firebase.ai.type.PublicPreviewAPI public com.google.firebase.ai.TemplateGenerativeModel templateGenerativeModel(); + method @com.google.firebase.ai.type.PublicPreviewAPI public com.google.firebase.ai.TemplateGenerativeModel templateGenerativeModel(com.google.firebase.ai.type.RequestOptions requestOptions = com.google.firebase.ai.type.RequestOptions()); + method @com.google.firebase.ai.type.PublicPreviewAPI public com.google.firebase.ai.TemplateImagenModel templateImagenModel(); + method @com.google.firebase.ai.type.PublicPreviewAPI public com.google.firebase.ai.TemplateImagenModel templateImagenModel(com.google.firebase.ai.type.RequestOptions requestOptions = com.google.firebase.ai.type.RequestOptions()); property public static final com.google.firebase.ai.FirebaseAI instance; field public static final com.google.firebase.ai.FirebaseAI.Companion Companion; } @@ -83,6 +87,15 @@ package com.google.firebase.ai { method public suspend Object? connect(kotlin.coroutines.Continuation); } + @com.google.firebase.ai.type.PublicPreviewAPI public final class TemplateGenerativeModel { + method public suspend Object? generateContent(String templateId, java.util.Map inputs, kotlin.coroutines.Continuation); + method public kotlinx.coroutines.flow.Flow generateContentStream(String templateId, java.util.Map inputs); + } + + @com.google.firebase.ai.type.PublicPreviewAPI public final class TemplateImagenModel { + method public suspend Object? generateImages(String templateId, java.util.Map inputs, kotlin.coroutines.Continuation>); + } + } package com.google.firebase.ai.java { @@ -166,6 +179,29 @@ package com.google.firebase.ai.java { method public com.google.firebase.ai.java.LiveSessionFutures from(com.google.firebase.ai.type.LiveSession session); } + public abstract class TemplateGenerativeModelFutures { + method public static final com.google.firebase.ai.java.TemplateGenerativeModelFutures from(com.google.firebase.ai.TemplateGenerativeModel model); + method public abstract com.google.common.util.concurrent.ListenableFuture generateContent(String templateId, java.util.Map inputs); + method public abstract org.reactivestreams.Publisher generateContentStream(String templateId, java.util.Map inputs); + method public abstract com.google.firebase.ai.TemplateGenerativeModel getGenerativeModel(); + field public static final com.google.firebase.ai.java.TemplateGenerativeModelFutures.Companion Companion; + } + + public static final class TemplateGenerativeModelFutures.Companion { + method public com.google.firebase.ai.java.TemplateGenerativeModelFutures from(com.google.firebase.ai.TemplateGenerativeModel model); + } + + public abstract class TemplateImagenModelFutures { + method public static final com.google.firebase.ai.java.TemplateImagenModelFutures from(com.google.firebase.ai.TemplateImagenModel model); + method public abstract com.google.common.util.concurrent.ListenableFuture> generateImages(String templateId, java.util.Map inputs); + method public abstract com.google.firebase.ai.TemplateImagenModel getImageModel(); + field public static final com.google.firebase.ai.java.TemplateImagenModelFutures.Companion Companion; + } + + public static final class TemplateImagenModelFutures.Companion { + method public com.google.firebase.ai.java.TemplateImagenModelFutures from(com.google.firebase.ai.TemplateImagenModel model); + } + } package com.google.firebase.ai.type { diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/FirebaseAI.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/FirebaseAI.kt index dd2309c984a..52eeed8959b 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/FirebaseAI.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/FirebaseAI.kt @@ -106,6 +106,29 @@ internal constructor( ) } + /** + * Instantiates a new [TemplateGenerativeModel] given the provided parameters. + * + * @param requestOptions Configuration options for sending requests to the backend. + * @return The initialized [TemplateGenerativeModel] instance. + */ + @JvmOverloads + @PublicPreviewAPI + public fun templateGenerativeModel( + requestOptions: RequestOptions = RequestOptions(), + ): TemplateGenerativeModel { + val templateUri = getTemplateUri(backend) + return TemplateGenerativeModel( + templateUri, + firebaseApp.options.apiKey, + firebaseApp, + useLimitedUseAppCheckTokens, + requestOptions, + appCheckProvider.get(), + internalAuthProvider.get(), + ) + } + /** * Instantiates a new [LiveGenerationConfig] given the provided parameters. * @@ -205,6 +228,29 @@ internal constructor( ) } + /** + * Instantiates a new [TemplateImagenModel] given the provided parameters. + * + * @param requestOptions Configuration options for sending requests to the backend. + * @return The initialized [TemplateImagenModel] instance. + */ + @JvmOverloads + @PublicPreviewAPI + public fun templateImagenModel( + requestOptions: RequestOptions = RequestOptions(), + ): TemplateImagenModel { + val templateUri = getTemplateUri(backend) + return TemplateImagenModel( + templateUri, + firebaseApp.options.apiKey, + firebaseApp, + useLimitedUseAppCheckTokens, + requestOptions, + appCheckProvider.get(), + internalAuthProvider.get(), + ) + } + public companion object { /** The [FirebaseAI] instance for the default [FirebaseApp] using the Google AI Backend. */ @JvmStatic @@ -258,6 +304,13 @@ internal constructor( private val TAG = FirebaseAI::class.java.simpleName } + + private fun getTemplateUri(backend: GenerativeBackend): String = + when (backend.backend) { + GenerativeBackendEnum.VERTEX_AI -> + "projects/${firebaseApp.options.projectId}/locations/${backend.location}/templates/" + GenerativeBackendEnum.GOOGLE_AI -> "projects/${firebaseApp.options.projectId}/templates/" + } } /** The [FirebaseAI] instance for the default [FirebaseApp] using the Google AI Backend. */ diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt index 62f11319f68..c562ea0233b 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt @@ -233,7 +233,7 @@ internal constructor( } @OptIn(PublicPreviewAPI::class) -private fun ImagenGenerationResponse.Internal.validate(): ImagenGenerationResponse.Internal { +internal fun ImagenGenerationResponse.Internal.validate(): ImagenGenerationResponse.Internal { if (predictions.none { it.mimeType != null }) { throw ContentBlockedException( message = predictions.first { it.raiFilteredReason != null }.raiFilteredReason diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt new file mode 100644 index 00000000000..8672813fff3 --- /dev/null +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt @@ -0,0 +1,141 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai + +import com.google.firebase.FirebaseApp +import com.google.firebase.ai.common.APIController +import com.google.firebase.ai.common.AppCheckHeaderProvider +import com.google.firebase.ai.common.TemplateGenerateContentRequest +import com.google.firebase.ai.type.Content +import com.google.firebase.ai.type.FinishReason +import com.google.firebase.ai.type.FirebaseAIException +import com.google.firebase.ai.type.GenerateContentResponse +import com.google.firebase.ai.type.PromptBlockedException +import com.google.firebase.ai.type.PublicPreviewAPI +import com.google.firebase.ai.type.RequestOptions +import com.google.firebase.ai.type.ResponseStoppedException +import com.google.firebase.ai.type.SerializationException +import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider +import com.google.firebase.auth.internal.InternalAuthProvider +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.catch +import kotlinx.coroutines.flow.map +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.jsonObject +import org.json.JSONObject + +/** + * Represents a multimodal model (like Gemini), capable of generating content based on various + * templated input types. + */ +@PublicPreviewAPI +public class TemplateGenerativeModel +internal constructor( + private val templateUri: String, + private val controller: APIController, +) { + + internal constructor( + templateUri: String, + apiKey: String, + firebaseApp: FirebaseApp, + useLimitedUseAppCheckTokens: Boolean, + requestOptions: RequestOptions = RequestOptions(), + appCheckTokenProvider: InteropAppCheckTokenProvider? = null, + internalAuthProvider: InternalAuthProvider? = null + ) : this( + templateUri, + APIController( + apiKey, + "", + requestOptions, + "gl-kotlin/${KotlinVersion.CURRENT}-ai fire/${BuildConfig.VERSION_NAME}", + firebaseApp, + AppCheckHeaderProvider( + TAG, + useLimitedUseAppCheckTokens, + appCheckTokenProvider, + internalAuthProvider + ), + ), + ) + + /** + * Generates content from a prompt template and inputs. + * + * @param templateId The ID of the prompt template to use. + * @param inputs A map of variables to substitute into the template. + * @return The content generated by the model. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + public suspend fun generateContent( + templateId: String, + inputs: Map, + ): GenerateContentResponse = + try { + controller + .templateGenerateContent("$templateUri$templateId", constructRequest(inputs)) + .toPublic() + .validate() + } catch (e: Throwable) { + throw FirebaseAIException.from(e) + } + + /** + * Generates content as a stream from a prompt template and inputs. + * + * @param templateId The ID of the prompt template to use. + * @param inputs A map of variables to substitute into the template. + * @return A [Flow] which will emit responses as they are returned by the model. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + public fun generateContentStream( + templateId: String, + inputs: Map + ): Flow = + controller + .templateGenerateContentStream("$templateUri$templateId", constructRequest(inputs)) + .catch { throw FirebaseAIException.from(it) } + .map { it.toPublic().validate() } + + internal fun constructRequest( + inputs: Map, + history: List? = null + ): TemplateGenerateContentRequest { + return TemplateGenerateContentRequest( + Json.parseToJsonElement(JSONObject(inputs).toString()).jsonObject, + history?.let { it.map { it.toTemplateInternal() } } + ) + } + + private fun GenerateContentResponse.validate() = apply { + if (candidates.isEmpty() && promptFeedback == null) { + throw SerializationException("Error deserializing response, found no valid fields") + } + promptFeedback?.blockReason?.let { throw PromptBlockedException(this) } + candidates + .mapNotNull { it.finishReason } + .firstOrNull { it != FinishReason.STOP } + ?.let { throw ResponseStoppedException(this) } + } + + private companion object { + private val TAG = TemplateGenerativeModel::class.java.simpleName + } +} diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateImagenModel.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateImagenModel.kt new file mode 100644 index 00000000000..e6281f49519 --- /dev/null +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateImagenModel.kt @@ -0,0 +1,106 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai + +import com.google.firebase.FirebaseApp +import com.google.firebase.ai.common.APIController +import com.google.firebase.ai.common.AppCheckHeaderProvider +import com.google.firebase.ai.common.TemplateGenerateImageRequest +import com.google.firebase.ai.type.FirebaseAIException +import com.google.firebase.ai.type.ImagenGenerationResponse +import com.google.firebase.ai.type.ImagenInlineImage +import com.google.firebase.ai.type.PublicPreviewAPI +import com.google.firebase.ai.type.RequestOptions +import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider +import com.google.firebase.auth.internal.InternalAuthProvider +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.jsonObject +import org.json.JSONObject + +/** + * Represents a generative model (like Imagen), capable of generating images based a template. + * + * See the documentation for a list of + * [supported models](https://firebase.google.com/docs/ai-logic/models). + */ +@PublicPreviewAPI +public class TemplateImagenModel +internal constructor( + private val templateUri: String, + private val controller: APIController, +) { + + @JvmOverloads + internal constructor( + templateUri: String, + apiKey: String, + firebaseApp: FirebaseApp, + useLimitedUseAppCheckTokens: Boolean, + requestOptions: RequestOptions = RequestOptions(), + appCheckTokenProvider: InteropAppCheckTokenProvider? = null, + internalAuthProvider: InternalAuthProvider? = null, + ) : this( + templateUri, + APIController( + apiKey, + "", + requestOptions, + "gl-kotlin/${KotlinVersion.CURRENT}-ai fire/${BuildConfig.VERSION_NAME}", + firebaseApp, + AppCheckHeaderProvider( + TAG, + useLimitedUseAppCheckTokens, + appCheckTokenProvider, + internalAuthProvider + ), + ), + ) + + /** + * Generates an image, returning the result directly to the caller. + * + * @param templateId The ID of server prompt template. + * @param inputs the inputs needed to fill in the prompt + */ + public suspend fun generateImages( + templateId: String, + inputs: Map + ): ImagenGenerationResponse = + try { + controller + .templateGenerateImage( + "$templateUri$templateId", + constructTemplateGenerateImageRequest(inputs) + ) + .validate() + .toPublicInline() + } catch (e: Throwable) { + throw FirebaseAIException.from(e) + } + + private fun constructTemplateGenerateImageRequest( + inputs: Map + ): TemplateGenerateImageRequest { + return TemplateGenerateImageRequest( + Json.parseToJsonElement(JSONObject(inputs).toString()).jsonObject + ) + } + + internal companion object { + private val TAG = TemplateImagenModel::class.java.simpleName + } +} diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt index 220e5efedac..e992f92e674 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt @@ -164,6 +164,38 @@ internal constructor( throw FirebaseAIException.from(e) } + suspend fun templateGenerateContent( + templateId: String, + request: TemplateGenerateContentRequest + ): GenerateContentResponse.Internal = + try { + client + .post( + "${requestOptions.endpoint}/${requestOptions.apiVersion}/$templateId:templateGenerateContent" + ) { + applyCommonConfiguration(request) + applyHeaderProvider() + } + .also { validateResponse(it) } + .body() + .validate() + } catch (e: Throwable) { + throw FirebaseAIException.from(e) + } + + fun templateGenerateContentStream( + templateId: String, + request: TemplateGenerateContentRequest + ): Flow = + client + .postStream( + "${requestOptions.endpoint}/${requestOptions.apiVersion}/$templateId:templateStreamGenerateContent?alt=sse" + ) { + applyCommonConfiguration(request) + } + .map { it.validate() } + .catch { throw FirebaseAIException.from(it) } + suspend fun generateImage(request: GenerateImageRequest): ImagenGenerationResponse.Internal = try { client @@ -177,6 +209,24 @@ internal constructor( throw FirebaseAIException.from(e) } + suspend fun templateGenerateImage( + templateId: String, + request: TemplateGenerateImageRequest + ): ImagenGenerationResponse.Internal = + try { + client + .post( + "${requestOptions.endpoint}/${requestOptions.apiVersion}/$templateId:templatePredict" + ) { + applyCommonConfiguration(request) + applyHeaderProvider() + } + .also { validateResponse(it) } + .body() + } catch (e: Throwable) { + throw FirebaseAIException.from(e) + } + private fun getBidiEndpoint(location: String): String = when (backend?.backend) { GenerativeBackendEnum.VERTEX_AI, @@ -228,6 +278,8 @@ internal constructor( is GenerateContentRequest -> setBody(request) is CountTokensRequest -> setBody(request) is GenerateImageRequest -> setBody(request) + is TemplateGenerateContentRequest -> setBody(request) + is TemplateGenerateImageRequest -> setBody(request) } applyCommonHeaders() } diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt index bb6bf242bb0..e3a4afcb56c 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt @@ -31,6 +31,7 @@ import com.google.firebase.ai.type.ToolConfig import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject internal interface Request @@ -45,6 +46,14 @@ internal data class GenerateContentRequest( @SerialName("system_instruction") val systemInstruction: Content.Internal? = null, ) : Request +@Serializable +internal data class TemplateGenerateContentRequest( + val inputs: JsonObject, + val history: List? +) : Request + +@Serializable internal data class TemplateGenerateImageRequest(val inputs: JsonObject) : Request + @Serializable internal data class CountTokensRequest( val generateContentRequest: GenerateContentRequest? = null, diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/java/TemplateGenerativeModelFutures.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/java/TemplateGenerativeModelFutures.kt new file mode 100644 index 00000000000..f1045e3db1d --- /dev/null +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/java/TemplateGenerativeModelFutures.kt @@ -0,0 +1,98 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai.java + +import androidx.concurrent.futures.SuspendToFutureAdapter +import com.google.common.util.concurrent.ListenableFuture +import com.google.firebase.ai.TemplateGenerativeModel +import com.google.firebase.ai.type.FirebaseAIException +import com.google.firebase.ai.type.GenerateContentResponse +import com.google.firebase.ai.type.PublicPreviewAPI +import kotlinx.coroutines.reactive.asPublisher +import org.reactivestreams.Publisher + +/** + * Wrapper class providing Java compatible methods for [TemplateGenerativeModel]. + * + * @see [TemplateGenerativeModel] + */ +@OptIn(PublicPreviewAPI::class) +public abstract class TemplateGenerativeModelFutures internal constructor() { + + /** + * Generates new content using the given templateId with the given inputs. + * + * @param templateId The ID of server prompt template. + * @param inputs the inputs needed to fill in the prompt + * @return The content generated by the model. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + public abstract fun generateContent( + templateId: String, + inputs: Map + ): ListenableFuture + + /** + * Generates new content as a stream using the given templateId with the given inputs. + * + * @param templateId The ID of server prompt template. + * @param inputs the inputs needed to fill in the prompt + * @return A [Publisher] which will emit responses as they are returned by the model. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + public abstract fun generateContentStream( + templateId: String, + inputs: Map + ): Publisher + + /** Returns the [TemplateGenerativeModel] object wrapped by this object. */ + public abstract fun getGenerativeModel(): TemplateGenerativeModel + + private class FuturesImpl(private val model: TemplateGenerativeModel) : + TemplateGenerativeModelFutures() { + override fun generateContent( + templateId: String, + inputs: Map + ): ListenableFuture { + return SuspendToFutureAdapter.launchFuture { model.generateContent(templateId, inputs) } + } + + override fun generateContentStream( + templateId: String, + inputs: Map + ): Publisher { + return model.generateContentStream(templateId, inputs).asPublisher() + } + + override fun getGenerativeModel(): TemplateGenerativeModel { + return model + } + } + + public companion object { + + /** + * @return a [TemplateGenerativeModelFutures] created around the provided + * [TemplateGenerativeModel] + */ + @JvmStatic + public fun from(model: TemplateGenerativeModel): TemplateGenerativeModelFutures = + FuturesImpl(model) + } +} diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/java/TemplateImagenModelFutures.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/java/TemplateImagenModelFutures.kt new file mode 100644 index 00000000000..5a38982fbda --- /dev/null +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/java/TemplateImagenModelFutures.kt @@ -0,0 +1,66 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.ai.java + +import androidx.concurrent.futures.SuspendToFutureAdapter +import com.google.common.util.concurrent.ListenableFuture +import com.google.firebase.ai.TemplateImagenModel +import com.google.firebase.ai.type.ImagenGenerationResponse +import com.google.firebase.ai.type.ImagenInlineImage +import com.google.firebase.ai.type.PublicPreviewAPI + +/** + * Wrapper class providing Java compatible methods for [TemplateImagenModel]. + * + * @see [TemplateImagenModel] + */ +@OptIn(PublicPreviewAPI::class) +public abstract class TemplateImagenModelFutures internal constructor() { + + /** + * Generates an image, returning the result directly to the caller. + * + * @param templateId The ID of server prompt template. + * @param inputs the inputs needed to fill in the prompt + */ + public abstract fun generateImages( + templateId: String, + inputs: Map + ): ListenableFuture> + + /** Returns the [TemplateImagenModel] object wrapped by this object. */ + public abstract fun getImageModel(): TemplateImagenModel + + private class FuturesImpl(private val model: TemplateImagenModel) : TemplateImagenModelFutures() { + override fun generateImages( + templateId: String, + inputs: Map + ): ListenableFuture> { + return SuspendToFutureAdapter.launchFuture { model.generateImages(templateId, inputs) } + } + + override fun getImageModel(): TemplateImagenModel { + return model + } + } + public companion object { + + /** @return a [TemplateImagenModelFutures] created around the provided [TemplateImagenModel] */ + @JvmStatic + public fun from(model: TemplateImagenModel): TemplateImagenModelFutures = FuturesImpl(model) + } +} diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt index 350d46e9063..d7450df3f3b 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Content.kt @@ -87,6 +87,10 @@ constructor(public val role: String? = "user", public val parts: List) { @OptIn(ExperimentalSerializationApi::class) internal fun toInternal() = Internal(this.role ?: "user", this.parts.map { it.toInternal() }) + @OptIn(ExperimentalSerializationApi::class) + internal fun toTemplateInternal() = + Internal(this.role ?: "user", this.parts.map { it.toInternal(true) }) + @ExperimentalSerializationApi @Serializable internal data class Internal( diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt index 6b578b8e46e..8e1702d5cd2 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Part.kt @@ -19,7 +19,6 @@ package com.google.firebase.ai.type import android.graphics.Bitmap import android.graphics.BitmapFactory import android.util.Log -import com.google.firebase.ai.type.ImagenImageFormat.Internal import java.io.ByteArrayOutputStream import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.SerialName @@ -337,13 +336,14 @@ internal object PartSerializer : } } -internal fun Part.toInternal(): InternalPart { +internal fun Part.toInternal(ignoreThoughtFlag: Boolean = false): InternalPart { + val thought = if (ignoreThoughtFlag) null else isThought return when (this) { - is TextPart -> TextPart.Internal(text, isThought, thoughtSignature) + is TextPart -> TextPart.Internal(text, thought, thoughtSignature) is ImagePart -> InlineDataPart.Internal( InlineData.Internal("image/jpeg", encodeBitmapToBase64Jpeg(image)), - isThought, + thought, thoughtSignature ) is InlineDataPart -> @@ -352,37 +352,37 @@ internal fun Part.toInternal(): InternalPart { mimeType, android.util.Base64.encodeToString(inlineData, BASE_64_FLAGS) ), - isThought, + thought, thoughtSignature ) is FunctionCallPart -> FunctionCallPart.Internal( FunctionCallPart.Internal.FunctionCall(name, args, id), - isThought, + thought, thoughtSignature ) is FunctionResponsePart -> FunctionResponsePart.Internal( FunctionResponsePart.Internal.FunctionResponse(name, response, id), - isThought, + thought, thoughtSignature ) is FileDataPart -> FileDataPart.Internal( FileDataPart.Internal.FileData(mimeType = mimeType, fileUri = uri), - isThought, + thought, thoughtSignature ) is ExecutableCodePart -> ExecutableCodePart.Internal( ExecutableCodePart.Internal.ExecutableCode(language, code), - isThought, + thought, thoughtSignature ) is CodeExecutionResultPart -> CodeExecutionResultPart.Internal( CodeExecutionResultPart.Internal.CodeExecutionResult(outcome, output), - isThought, + thought, thoughtSignature ) else -> diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt index 476d68261d2..215b1eca9eb 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/SerializationTests.kt @@ -16,6 +16,8 @@ package com.google.firebase.ai +import com.google.firebase.ai.common.TemplateGenerateContentRequest +import com.google.firebase.ai.common.TemplateGenerateImageRequest import com.google.firebase.ai.common.util.descriptorToJson import com.google.firebase.ai.type.Candidate import com.google.firebase.ai.type.CountTokensResponse @@ -512,6 +514,56 @@ internal class SerializationTests { expectedJsonAsString shouldEqualJson actualJson.toString() } + @Test + fun `test template request serialization as Json`() { + val expectedJsonAsString = + """ + { + "id": "TemplateGenerateContentRequest", + "type": "object", + "properties": { + "inputs": { + "type": "object", + "additionalProperties": { + "${"$"}ref": "JsonElement" + } + }, + "history": { + "type": "array", + "items": { + "${"$"}ref": "Content" + } + } + } + } + """ + .trimIndent() + val actualJson = descriptorToJson(TemplateGenerateContentRequest.serializer().descriptor) + expectedJsonAsString shouldEqualJson actualJson.toString() + } + + @Test + fun `test template imagen request serialization as Json`() { + val expectedJsonAsString = + """ + { + "id": "TemplateGenerateImageRequest", + "type": "object", + "properties": { + "inputs": { + "type": "object", + "additionalProperties": { + "${"$"}ref": "JsonElement" + } + } + } + } + """ + .trimIndent() + val actualJson = descriptorToJson(TemplateGenerateImageRequest.serializer().descriptor) + expectedJsonAsString shouldEqualJson actualJson.toString() + } + @Test fun `test GoogleSearch serialization as Json`() { val expectedJsonAsString =