Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ package com.google.firebase.vertexai.common
import android.util.Log
import com.google.firebase.Firebase
import com.google.firebase.options
import com.google.firebase.vertexai.common.server.FinishReason
import com.google.firebase.vertexai.common.server.GRpcError
import com.google.firebase.vertexai.common.server.GRpcErrorDetails
import com.google.firebase.vertexai.common.util.decodeToFlow
import com.google.firebase.vertexai.common.util.fullModelName
import com.google.firebase.vertexai.type.CountTokensResponse
import com.google.firebase.vertexai.type.FinishReason
import com.google.firebase.vertexai.type.GRpcErrorResponse
import com.google.firebase.vertexai.type.GenerateContentResponse
import com.google.firebase.vertexai.type.RequestOptions
import com.google.firebase.vertexai.type.Response
import io.ktor.client.HttpClient
import io.ktor.client.call.body
import io.ktor.client.engine.HttpClientEngine
Expand Down Expand Up @@ -106,31 +108,33 @@ internal constructor(
install(ContentNegotiation) { json(JSON) }
}

suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse.Internal =
try {
client
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") {
applyCommonConfiguration(request)
applyHeaderProvider()
}
.also { validateResponse(it) }
.body<GenerateContentResponse>()
.body<GenerateContentResponse.Internal>()
.validate()
} catch (e: Throwable) {
throw FirebaseCommonAIException.from(e)
}

fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> =
fun generateContentStream(
request: GenerateContentRequest
): Flow<GenerateContentResponse.Internal> =
client
.postStream<GenerateContentResponse>(
.postStream<GenerateContentResponse.Internal>(
"${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse"
) {
applyCommonConfiguration(request)
}
.map { it.validate() }
.catch { throw FirebaseCommonAIException.from(it) }

suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse.Internal =
try {
client
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") {
Expand Down Expand Up @@ -275,19 +279,21 @@ private suspend fun validateResponse(response: HttpResponse) {
throw ServerException(message)
}

private fun getServiceDisabledErrorDetailsOrNull(error: GRpcError): GRpcErrorDetails? {
private fun getServiceDisabledErrorDetailsOrNull(
error: GRpcErrorResponse.GRpcError
): GRpcErrorResponse.GRpcError.GRpcErrorDetails? {
return error.details?.firstOrNull {
it.reason == "SERVICE_DISABLED" && it.domain == "googleapis.com"
}
}

private fun GenerateContentResponse.validate() = apply {
private fun GenerateContentResponse.Internal.validate() = apply {
if ((candidates?.isEmpty() != false) && promptFeedback == null) {
throw SerializationException("Error deserializing response, found no valid fields")
}
promptFeedback?.blockReason?.let { throw PromptBlockedException(this) }
candidates
?.mapNotNull { it.finishReason }
?.firstOrNull { it != FinishReason.STOP }
?.firstOrNull { it != FinishReason.Internal.STOP }
?.let { throw ResponseStoppedException(this) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.firebase.vertexai.common

import com.google.firebase.vertexai.type.GenerateContentResponse
import io.ktor.serialization.JsonConvertException
import kotlinx.coroutines.TimeoutCancellationException

Expand Down Expand Up @@ -66,7 +67,7 @@ internal class InvalidAPIKeyException(message: String, cause: Throwable? = null)
* @property response the full server response for the request.
*/
internal class PromptBlockedException(
val response: GenerateContentResponse,
val response: GenerateContentResponse.Internal,
cause: Throwable? = null
) :
FirebaseCommonAIException(
Expand Down Expand Up @@ -98,7 +99,7 @@ internal class InvalidStateException(message: String, cause: Throwable? = null)
* @property response the full server response for the request
*/
internal class ResponseStoppedException(
val response: GenerateContentResponse,
val response: GenerateContentResponse.Internal,
cause: Throwable? = null
) :
FirebaseCommonAIException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@

package com.google.firebase.vertexai.common

import com.google.firebase.vertexai.common.client.GenerationConfig
import com.google.firebase.vertexai.common.client.Tool
import com.google.firebase.vertexai.common.client.ToolConfig
import com.google.firebase.vertexai.common.shared.Content
import com.google.firebase.vertexai.common.shared.SafetySetting
import com.google.firebase.vertexai.common.util.fullModelName
import com.google.firebase.vertexai.type.Content
import com.google.firebase.vertexai.type.GenerationConfig
import com.google.firebase.vertexai.type.SafetySetting
import com.google.firebase.vertexai.type.Tool
import com.google.firebase.vertexai.type.ToolConfig
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

Expand All @@ -30,21 +30,21 @@ internal sealed interface Request
@Serializable
internal data class GenerateContentRequest(
val model: String? = null,
val contents: List<Content>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig? = null,
val tools: List<Tool>? = null,
@SerialName("tool_config") var toolConfig: ToolConfig? = null,
@SerialName("system_instruction") val systemInstruction: Content? = null,
val contents: List<Content.Internal>,
@SerialName("safety_settings") val safetySettings: List<SafetySetting.Internal>? = null,
@SerialName("generation_config") val generationConfig: GenerationConfig.Internal? = null,
val tools: List<Tool.Internal>? = null,
@SerialName("tool_config") var toolConfig: ToolConfig.Internal? = null,
@SerialName("system_instruction") val systemInstruction: Content.Internal? = null,
) : Request

@Serializable
internal data class CountTokensRequest(
val generateContentRequest: GenerateContentRequest? = null,
val model: String? = null,
val contents: List<Content>? = null,
val tools: List<Tool>? = null,
@SerialName("system_instruction") val systemInstruction: Content? = null,
val contents: List<Content.Internal>? = null,
val tools: List<Tool.Internal>? = null,
@SerialName("system_instruction") val systemInstruction: Content.Internal? = null,
) : Request {
companion object {
fun forGenAI(generateContentRequest: GenerateContentRequest) =
Expand Down

This file was deleted.

This file was deleted.

Loading
Loading