1 change: 1 addition & 0 deletions .changes/generativeai/decision-amount-amusement-bite.json
@@ -0,0 +1 @@
+{"type":"PATCH","changes":["Require at least one argument for functions that take vararg"]}
@@ -57,7 +57,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
     prompt.assertComesFromUser()
     attemptLock()
     try {
-      val response = model.generateContent(*history.toTypedArray(), prompt)
+      val response = model.generateContent(prompt, *history.toTypedArray())
       history.add(prompt)
       history.add(response.candidates.first().content)
       return response
@@ -100,7 +100,7 @@ class Chat(private val model: GenerativeModel, val history: MutableList<Content>
     prompt.assertComesFromUser()
     attemptLock()

-    val flow = model.generateContentStream(*history.toTypedArray(), prompt)
+    val flow = model.generateContentStream(prompt, *history.toTypedArray())
     val bitmaps = LinkedList<Bitmap>()
     val blobs = LinkedList<BlobPart>()
     val text = StringBuilder()
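One knock-on effect, visible in the Chat call sites above: a spread array alone no longer satisfies the new signature, so a caller holding only a collection must peel off the first element itself. A hedged sketch of one way to adapt such a caller (the helper name is illustrative, not part of the SDK):

// Before this change `model.generateContent(*contents.toTypedArray())` sufficed;
// now the first Content must be passed positionally and the rest spread.
suspend fun generateFromList(
  model: GenerativeModel,
  contents: List<Content>
): GenerateContentResponse {
  require(contents.isNotEmpty()) { "at least one Content is required" }
  return model.generateContent(contents.first(), *contents.drop(1).toTypedArray())
}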
@@ -108,9 +108,9 @@ internal constructor(
    * @return A [GenerateContentResponse] after some delay. Function should be called within a
    * suspend context to properly manage concurrency.
    */
-  suspend fun generateContent(vararg prompt: Content): GenerateContentResponse =
+  suspend fun generateContent(prompt: Content, vararg prompts: Content): GenerateContentResponse =
     try {
-      controller.generateContent(constructRequest(*prompt)).toPublic().validate()
+      controller.generateContent(constructRequest(prompt, *prompts)).toPublic().validate()
     } catch (e: Throwable) {
       throw GoogleGenerativeAIException.from(e)
     }
@@ -121,9 +121,12 @@ internal constructor(
    * @param prompt A group of [Content]s to send to the model.
    * @return A [Flow] which will emit responses as they are returned from the model.
    */
-  fun generateContentStream(vararg prompt: Content): Flow<GenerateContentResponse> =
+  fun generateContentStream(
+    prompt: Content,
+    vararg prompts: Content
+  ): Flow<GenerateContentResponse> =
     controller
-      .generateContentStream(constructRequest(*prompt))
+      .generateContentStream(constructRequest(prompt, *prompts))
       .catch { throw GoogleGenerativeAIException.from(it) }
       .map { it.toPublic().validate() }

@@ -174,8 +177,8 @@ internal constructor(
    * @param prompt A group of [Content]s to count tokens of.
    * @return A [CountTokensResponse] containing the number of tokens in the prompt.
    */
-  suspend fun countTokens(vararg prompt: Content): CountTokensResponse {
-    return controller.countTokens(constructCountTokensRequest(*prompt)).toPublic()
+  suspend fun countTokens(prompt: Content, vararg prompts: Content): CountTokensResponse {
+    return controller.countTokens(constructCountTokensRequest(prompt, *prompts)).toPublic()
   }

   /**
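With all three GenerativeModel entry points reshaped, empty calls such as model.generateContent() or model.countTokens() now fail at compile time instead of reaching the backend. A brief usage sketch against the signatures above; the content { ... } builder and the .text and .totalTokens accessors are assumed from the surrounding SDK rather than shown in this diff:

// Inside a suspend context; single and multi-argument calls both satisfy the
// required `prompt` parameter, and any extras flow into the `prompts` vararg.
suspend fun demo(model: GenerativeModel, bitmap: Bitmap) {
  val count = model.countTokens(content { text("Hello") })
  val response = model.generateContent(
    content { text("Describe this image") },
    content { image(bitmap) } // `bitmap` is an assumed android.graphics.Bitmap
  )
  println("${count.totalTokens} tokens in; reply: ${response.text}")
}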
@@ -38,21 +38,30 @@ abstract class GenerativeModelFutures internal constructor() {
    *
    * @param prompt A group of [Content]s to send to the model.
    */
-  abstract fun generateContent(vararg prompt: Content): ListenableFuture<GenerateContentResponse>
+  abstract fun generateContent(
+    prompt: Content,
+    vararg prompts: Content
+  ): ListenableFuture<GenerateContentResponse>

   /**
    * Generates a streaming response from the backend with the provided [Content]s.
    *
    * @param prompt A group of [Content]s to send to the model.
    */
-  abstract fun generateContentStream(vararg prompt: Content): Publisher<GenerateContentResponse>
+  abstract fun generateContentStream(
+    prompt: Content,
+    vararg prompts: Content
+  ): Publisher<GenerateContentResponse>

   /**
    * Counts the number of tokens used in a prompt.
    *
    * @param prompt A group of [Content]s to count tokens of.
    */
-  abstract fun countTokens(vararg prompt: Content): ListenableFuture<CountTokensResponse>
+  abstract fun countTokens(
+    prompt: Content,
+    vararg prompts: Content
+  ): ListenableFuture<CountTokensResponse>

   /** Creates a chat instance which internally tracks the ongoing conversation with the model */
   abstract fun startChat(): ChatFutures
@@ -69,15 +78,22 @@ abstract class GenerativeModelFutures internal constructor() {

   private class FuturesImpl(private val model: GenerativeModel) : GenerativeModelFutures() {
     override fun generateContent(
-      vararg prompt: Content
+      prompt: Content,
+      vararg prompts: Content
     ): ListenableFuture<GenerateContentResponse> =
-      SuspendToFutureAdapter.launchFuture { model.generateContent(*prompt) }
-
-    override fun generateContentStream(vararg prompt: Content): Publisher<GenerateContentResponse> =
-      model.generateContentStream(*prompt).asPublisher()
-
-    override fun countTokens(vararg prompt: Content): ListenableFuture<CountTokensResponse> =
-      SuspendToFutureAdapter.launchFuture { model.countTokens(*prompt) }
+      SuspendToFutureAdapter.launchFuture { model.generateContent(prompt, *prompts) }
+
+    override fun generateContentStream(
+      prompt: Content,
+      vararg prompts: Content
+    ): Publisher<GenerateContentResponse> =
+      model.generateContentStream(prompt, *prompts).asPublisher()
+
+    override fun countTokens(
+      prompt: Content,
+      vararg prompts: Content
+    ): ListenableFuture<CountTokensResponse> =
+      SuspendToFutureAdapter.launchFuture { model.countTokens(prompt, *prompts) }

     override fun startChat(): ChatFutures = startChat(emptyList())

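The futures wrapper preserves the same rule for Java-oriented callers. A short sketch, assuming the SDK's GenerativeModelFutures.from(...) factory and the same content { ... } builder (neither appears in this diff):

// The first Content is mandatory here too; callback wiring is omitted.
val futures: GenerativeModelFutures = GenerativeModelFutures.from(model)
val future: ListenableFuture<CountTokensResponse> =
  futures.countTokens(content { text("How many tokens is this?") })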