diff --git a/firebase-ai/CHANGELOG.md b/firebase-ai/CHANGELOG.md index 1fbb39c786a..59f410ef480 100644 --- a/firebase-ai/CHANGELOG.md +++ b/firebase-ai/CHANGELOG.md @@ -3,6 +3,10 @@ 2.5 series models. (#6990) * [feature] **Breaking Change**: Add support for Grounding with Google Search (#7042). * **Action Required:** Update all references of `groundingAttributions`, `webSearchQueries`, `retrievalQueries` in `GroundingMetadata` to be non-optional. +* [changed] require at least one argument for `generateContent()`, `generateContentStream()` and + `countTokens()`. +* [feature] Added new overloads for `generateContent()`, `generateContentStream()` and + `countTokens()` that take a `List` parameter. # 16.2.0 * [changed] Deprecate the `totalBillableCharacters` field (only usable with pre-2.0 models). (#7042) @@ -34,3 +38,4 @@ Note: This feature is in Public Preview, which means that it is not subject to any SLA or deprecation policy and could change in backwards-incompatible ways. + diff --git a/firebase-ai/api.txt b/firebase-ai/api.txt index 18b306482a1..70b35587515 100644 --- a/firebase-ai/api.txt +++ b/firebase-ai/api.txt @@ -53,14 +53,17 @@ package com.google.firebase.ai { public final class GenerativeModel { method public suspend Object? countTokens(android.graphics.Bitmap prompt, kotlin.coroutines.Continuation); - method public suspend Object? countTokens(com.google.firebase.ai.type.Content[] prompt, kotlin.coroutines.Continuation); + method public suspend Object? countTokens(com.google.firebase.ai.type.Content prompt, com.google.firebase.ai.type.Content[] prompts, kotlin.coroutines.Continuation); method public suspend Object? countTokens(String prompt, kotlin.coroutines.Continuation); + method public suspend Object? countTokens(java.util.List prompt, kotlin.coroutines.Continuation); method public suspend Object? generateContent(android.graphics.Bitmap prompt, kotlin.coroutines.Continuation); - method public suspend Object? generateContent(com.google.firebase.ai.type.Content[] prompt, kotlin.coroutines.Continuation); + method public suspend Object? generateContent(com.google.firebase.ai.type.Content prompt, com.google.firebase.ai.type.Content[] prompts, kotlin.coroutines.Continuation); method public suspend Object? generateContent(String prompt, kotlin.coroutines.Continuation); + method public suspend Object? generateContent(java.util.List prompt, kotlin.coroutines.Continuation); method public kotlinx.coroutines.flow.Flow generateContentStream(android.graphics.Bitmap prompt); - method public kotlinx.coroutines.flow.Flow generateContentStream(com.google.firebase.ai.type.Content... prompt); + method public kotlinx.coroutines.flow.Flow generateContentStream(com.google.firebase.ai.type.Content prompt, com.google.firebase.ai.type.Content... prompts); method public kotlinx.coroutines.flow.Flow generateContentStream(String prompt); + method public kotlinx.coroutines.flow.Flow generateContentStream(java.util.List prompt); method public com.google.firebase.ai.Chat startChat(java.util.List history = emptyList()); } @@ -89,10 +92,10 @@ package com.google.firebase.ai.java { } public abstract class GenerativeModelFutures { - method public abstract com.google.common.util.concurrent.ListenableFuture countTokens(com.google.firebase.ai.type.Content... prompt); + method public abstract com.google.common.util.concurrent.ListenableFuture countTokens(com.google.firebase.ai.type.Content prompt, com.google.firebase.ai.type.Content... prompts); method public static final com.google.firebase.ai.java.GenerativeModelFutures from(com.google.firebase.ai.GenerativeModel model); - method public abstract com.google.common.util.concurrent.ListenableFuture generateContent(com.google.firebase.ai.type.Content... prompt); - method public abstract org.reactivestreams.Publisher generateContentStream(com.google.firebase.ai.type.Content... prompt); + method public abstract com.google.common.util.concurrent.ListenableFuture generateContent(com.google.firebase.ai.type.Content prompt, com.google.firebase.ai.type.Content... prompts); + method public abstract org.reactivestreams.Publisher generateContentStream(com.google.firebase.ai.type.Content prompt, com.google.firebase.ai.type.Content... prompts); method public abstract com.google.firebase.ai.GenerativeModel getGenerativeModel(); method public abstract com.google.firebase.ai.java.ChatFutures startChat(); method public abstract com.google.firebase.ai.java.ChatFutures startChat(java.util.List history); diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt index 13599fb1c9a..73d304d3885 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt @@ -66,7 +66,8 @@ public class Chat( prompt.assertComesFromUser() attemptLock() try { - val response = model.generateContent(*history.toTypedArray(), prompt) + val fullPrompt = history + prompt + val response = model.generateContent(fullPrompt.first(), *fullPrompt.drop(1).toTypedArray()) history.add(prompt) history.add(response.candidates.first().content) return response @@ -127,7 +128,8 @@ public class Chat( prompt.assertComesFromUser() attemptLock() - val flow = model.generateContentStream(*history.toTypedArray(), prompt) + val fullPrompt = history + prompt + val flow = model.generateContentStream(fullPrompt.first(), *fullPrompt.drop(1).toTypedArray()) val bitmaps = LinkedList() val inlineDataParts = LinkedList() val text = StringBuilder() diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/GenerativeModel.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/GenerativeModel.kt index 1b36998f970..286f61fdb8a 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/GenerativeModel.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/GenerativeModel.kt @@ -100,13 +100,48 @@ internal constructor( * @throws [FirebaseAIException] if the request failed. * @see [FirebaseAIException] for types of errors. */ - public suspend fun generateContent(vararg prompt: Content): GenerateContentResponse = + public suspend fun generateContent( + prompt: Content, + vararg prompts: Content + ): GenerateContentResponse = try { - controller.generateContent(constructRequest(*prompt)).toPublic().validate() + controller.generateContent(constructRequest(prompt, *prompts)).toPublic().validate() } catch (e: Throwable) { throw FirebaseAIException.from(e) } + /** + * Generates new content from the input [Content] given to the model as a prompt. + * + * @param prompt The input(s) given to the model as a prompt. + * @return The content generated by the model. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + public suspend fun generateContent(prompt: List): GenerateContentResponse = + try { + controller.generateContent(constructRequest(prompt)).toPublic().validate() + } catch (e: Throwable) { + throw FirebaseAIException.from(e) + } + + /** + * Generates new content as a stream from the input [Content] given to the model as a prompt. + * + * @param prompt The input(s) given to the model as a prompt. + * @return A [Flow] which will emit responses as they are returned by the model. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + public fun generateContentStream( + prompt: Content, + vararg prompts: Content + ): Flow = + controller + .generateContentStream(constructRequest(prompt, *prompts)) + .catch { throw FirebaseAIException.from(it) } + .map { it.toPublic().validate() } + /** * Generates new content as a stream from the input [Content] given to the model as a prompt. * @@ -115,9 +150,9 @@ internal constructor( * @throws [FirebaseAIException] if the request failed. * @see [FirebaseAIException] for types of errors. */ - public fun generateContentStream(vararg prompt: Content): Flow = + public fun generateContentStream(prompt: List): Flow = controller - .generateContentStream(constructRequest(*prompt)) + .generateContentStream(constructRequest(prompt)) .catch { throw FirebaseAIException.from(it) } .map { it.toPublic().validate() } @@ -177,9 +212,25 @@ internal constructor( * @throws [FirebaseAIException] if the request failed. * @see [FirebaseAIException] for types of errors. */ - public suspend fun countTokens(vararg prompt: Content): CountTokensResponse { + public suspend fun countTokens(prompt: Content, vararg prompts: Content): CountTokensResponse { + try { + return controller.countTokens(constructCountTokensRequest(prompt, *prompts)).toPublic() + } catch (e: Throwable) { + throw FirebaseAIException.from(e) + } + } + + /** + * Counts the number of tokens in a prompt using the model's tokenizer. + * + * @param prompt The input(s) given to the model as a prompt. + * @return The [CountTokensResponse] of running the model's tokenizer on the input. + * @throws [FirebaseAIException] if the request failed. + * @see [FirebaseAIException] for types of errors. + */ + public suspend fun countTokens(prompt: List): CountTokensResponse { try { - return controller.countTokens(constructCountTokensRequest(*prompt)).toPublic() + return controller.countTokens(constructCountTokensRequest(*prompt.toTypedArray())).toPublic() } catch (e: Throwable) { throw FirebaseAIException.from(e) } @@ -232,6 +283,8 @@ internal constructor( systemInstruction?.copy(role = "system")?.toInternal(), ) + private fun constructRequest(prompt: List) = constructRequest(*prompt.toTypedArray()) + private fun constructCountTokensRequest(vararg prompt: Content) = when (generativeBackend.backend) { GenerativeBackendEnum.GOOGLE_AI -> CountTokensRequest.forGoogleAI(constructRequest(*prompt)) diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/java/GenerativeModelFutures.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/java/GenerativeModelFutures.kt index 57a531c1cd8..51a90135e12 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/java/GenerativeModelFutures.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/java/GenerativeModelFutures.kt @@ -42,7 +42,8 @@ public abstract class GenerativeModelFutures internal constructor() { * @throws [FirebaseAIException] if the request failed. */ public abstract fun generateContent( - vararg prompt: Content + prompt: Content, + vararg prompts: Content ): ListenableFuture /** @@ -53,7 +54,8 @@ public abstract class GenerativeModelFutures internal constructor() { * @throws [FirebaseAIException] if the request failed. */ public abstract fun generateContentStream( - vararg prompt: Content + prompt: Content, + vararg prompts: Content ): Publisher /** @@ -63,7 +65,10 @@ public abstract class GenerativeModelFutures internal constructor() { * @return The [CountTokensResponse] of running the model's tokenizer on the input. * @throws [FirebaseAIException] if the request failed. */ - public abstract fun countTokens(vararg prompt: Content): ListenableFuture + public abstract fun countTokens( + prompt: Content, + vararg prompts: Content + ): ListenableFuture /** * Creates a [ChatFutures] instance which internally tracks the ongoing conversation with the @@ -83,15 +88,22 @@ public abstract class GenerativeModelFutures internal constructor() { private class FuturesImpl(private val model: GenerativeModel) : GenerativeModelFutures() { override fun generateContent( - vararg prompt: Content + prompt: Content, + vararg prompts: Content ): ListenableFuture = - SuspendToFutureAdapter.launchFuture { model.generateContent(*prompt) } - - override fun generateContentStream(vararg prompt: Content): Publisher = - model.generateContentStream(*prompt).asPublisher() - - override fun countTokens(vararg prompt: Content): ListenableFuture = - SuspendToFutureAdapter.launchFuture { model.countTokens(*prompt) } + SuspendToFutureAdapter.launchFuture { model.generateContent(prompt, *prompts) } + + override fun generateContentStream( + prompt: Content, + vararg prompts: Content + ): Publisher = + model.generateContentStream(prompt, *prompts).asPublisher() + + override fun countTokens( + prompt: Content, + vararg prompts: Content + ): ListenableFuture = + SuspendToFutureAdapter.launchFuture { model.countTokens(prompt, *prompts) } override fun startChat(): ChatFutures = startChat(emptyList())