diff --git a/firebase-vertexai/consumer-rules.pro b/firebase-vertexai/consumer-rules.pro index 7947f53cd58..f328794a748 100644 --- a/firebase-vertexai/consumer-rules.pro +++ b/firebase-vertexai/consumer-rules.pro @@ -20,4 +20,5 @@ # hide the original source file name. #-renamesourcefileattribute SourceFile +-keep class com.google.firebase.vertexai.type.** { *; } -keep class com.google.firebase.vertexai.common.** { *; } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt index 45c6aa9bce4..c2b9fcfd2f9 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt @@ -24,8 +24,6 @@ import com.google.firebase.vertexai.common.APIController import com.google.firebase.vertexai.common.CountTokensRequest import com.google.firebase.vertexai.common.GenerateContentRequest import com.google.firebase.vertexai.common.HeaderProvider -import com.google.firebase.vertexai.internal.util.toInternal -import com.google.firebase.vertexai.internal.util.toPublic import com.google.firebase.vertexai.type.Content import com.google.firebase.vertexai.type.CountTokensResponse import com.google.firebase.vertexai.type.FinishReason diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt index cbcc0c69e97..286b8829241 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt @@ -19,12 +19,14 @@ package com.google.firebase.vertexai.common import android.util.Log import com.google.firebase.Firebase import com.google.firebase.options -import com.google.firebase.vertexai.common.server.FinishReason -import com.google.firebase.vertexai.common.server.GRpcError -import com.google.firebase.vertexai.common.server.GRpcErrorDetails import com.google.firebase.vertexai.common.util.decodeToFlow import com.google.firebase.vertexai.common.util.fullModelName +import com.google.firebase.vertexai.type.CountTokensResponse +import com.google.firebase.vertexai.type.FinishReason +import com.google.firebase.vertexai.type.GRpcErrorResponse +import com.google.firebase.vertexai.type.GenerateContentResponse import com.google.firebase.vertexai.type.RequestOptions +import com.google.firebase.vertexai.type.Response import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.engine.HttpClientEngine @@ -106,7 +108,7 @@ internal constructor( install(ContentNegotiation) { json(JSON) } } - suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse = + suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse.Internal = try { client .post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") { @@ -114,15 +116,17 @@ internal constructor( applyHeaderProvider() } .also { validateResponse(it) } - .body() + .body() .validate() } catch (e: Throwable) { throw FirebaseCommonAIException.from(e) } - fun generateContentStream(request: GenerateContentRequest): Flow = + fun generateContentStream( + request: GenerateContentRequest + ): Flow = client - .postStream( + .postStream( "${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" ) { applyCommonConfiguration(request) @@ -130,7 +134,7 @@ internal constructor( .map { it.validate() } .catch { throw FirebaseCommonAIException.from(it) } - suspend fun countTokens(request: CountTokensRequest): CountTokensResponse = + suspend fun countTokens(request: CountTokensRequest): CountTokensResponse.Internal = try { client .post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") { @@ -275,19 +279,21 @@ private suspend fun validateResponse(response: HttpResponse) { throw ServerException(message) } -private fun getServiceDisabledErrorDetailsOrNull(error: GRpcError): GRpcErrorDetails? { +private fun getServiceDisabledErrorDetailsOrNull( + error: GRpcErrorResponse.GRpcError +): GRpcErrorResponse.GRpcError.GRpcErrorDetails? { return error.details?.firstOrNull { it.reason == "SERVICE_DISABLED" && it.domain == "googleapis.com" } } -private fun GenerateContentResponse.validate() = apply { +private fun GenerateContentResponse.Internal.validate() = apply { if ((candidates?.isEmpty() != false) && promptFeedback == null) { throw SerializationException("Error deserializing response, found no valid fields") } promptFeedback?.blockReason?.let { throw PromptBlockedException(this) } candidates ?.mapNotNull { it.finishReason } - ?.firstOrNull { it != FinishReason.STOP } + ?.firstOrNull { it != FinishReason.Internal.STOP } ?.let { throw ResponseStoppedException(this) } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt index 41954b13497..7567c384618 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt @@ -16,6 +16,7 @@ package com.google.firebase.vertexai.common +import com.google.firebase.vertexai.type.GenerateContentResponse import io.ktor.serialization.JsonConvertException import kotlinx.coroutines.TimeoutCancellationException @@ -66,7 +67,7 @@ internal class InvalidAPIKeyException(message: String, cause: Throwable? = null) * @property response the full server response for the request. */ internal class PromptBlockedException( - val response: GenerateContentResponse, + val response: GenerateContentResponse.Internal, cause: Throwable? = null ) : FirebaseCommonAIException( @@ -98,7 +99,7 @@ internal class InvalidStateException(message: String, cause: Throwable? = null) * @property response the full server response for the request */ internal class ResponseStoppedException( - val response: GenerateContentResponse, + val response: GenerateContentResponse.Internal, cause: Throwable? = null ) : FirebaseCommonAIException( @@ -125,3 +126,18 @@ internal class ServiceDisabledException(message: String, cause: Throwable? = nul /** Catch all case for exceptions not explicitly expected. */ internal class UnknownException(message: String, cause: Throwable? = null) : FirebaseCommonAIException(message, cause) + +internal fun makeMissingCaseException( + source: String, + ordinal: Int +): com.google.firebase.vertexai.type.SerializationException { + return com.google.firebase.vertexai.type.SerializationException( + """ + |Missing case for a $source: $ordinal + |This error indicates that one of the `toInternal` conversions needs updating. + |If you're a developer seeing this exception, please file an issue on our GitHub repo: + |https://github.com/firebase/firebase-android-sdk + """ + .trimMargin() + ) +} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt index 39adea5629c..040a38e0a0b 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt @@ -16,12 +16,12 @@ package com.google.firebase.vertexai.common -import com.google.firebase.vertexai.common.client.GenerationConfig -import com.google.firebase.vertexai.common.client.Tool -import com.google.firebase.vertexai.common.client.ToolConfig -import com.google.firebase.vertexai.common.shared.Content -import com.google.firebase.vertexai.common.shared.SafetySetting import com.google.firebase.vertexai.common.util.fullModelName +import com.google.firebase.vertexai.type.Content +import com.google.firebase.vertexai.type.GenerationConfig +import com.google.firebase.vertexai.type.SafetySetting +import com.google.firebase.vertexai.type.Tool +import com.google.firebase.vertexai.type.ToolConfig import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -30,21 +30,21 @@ internal sealed interface Request @Serializable internal data class GenerateContentRequest( val model: String? = null, - val contents: List, - @SerialName("safety_settings") val safetySettings: List? = null, - @SerialName("generation_config") val generationConfig: GenerationConfig? = null, - val tools: List? = null, - @SerialName("tool_config") var toolConfig: ToolConfig? = null, - @SerialName("system_instruction") val systemInstruction: Content? = null, + val contents: List, + @SerialName("safety_settings") val safetySettings: List? = null, + @SerialName("generation_config") val generationConfig: GenerationConfig.Internal? = null, + val tools: List? = null, + @SerialName("tool_config") var toolConfig: ToolConfig.Internal? = null, + @SerialName("system_instruction") val systemInstruction: Content.Internal? = null, ) : Request @Serializable internal data class CountTokensRequest( val generateContentRequest: GenerateContentRequest? = null, val model: String? = null, - val contents: List? = null, - val tools: List? = null, - @SerialName("system_instruction") val systemInstruction: Content? = null, + val contents: List? = null, + val tools: List? = null, + @SerialName("system_instruction") val systemInstruction: Content.Internal? = null, ) : Request { companion object { fun forGenAI(generateContentRequest: GenerateContentRequest) = diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Response.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Response.kt deleted file mode 100644 index d8182883442..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Response.kt +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common - -import com.google.firebase.vertexai.common.server.Candidate -import com.google.firebase.vertexai.common.server.GRpcError -import com.google.firebase.vertexai.common.server.PromptFeedback -import kotlinx.serialization.Serializable - -internal sealed interface Response - -@Serializable -internal data class GenerateContentResponse( - val candidates: List? = null, - val promptFeedback: PromptFeedback? = null, - val usageMetadata: UsageMetadata? = null, -) : Response - -@Serializable -internal data class CountTokensResponse( - val totalTokens: Int, - val totalBillableCharacters: Int? = null -) : Response - -@Serializable internal data class GRpcErrorResponse(val error: GRpcError) : Response - -@Serializable -internal data class UsageMetadata( - val promptTokenCount: Int? = null, - val candidatesTokenCount: Int? = null, - val totalTokenCount: Int? = null, -) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt deleted file mode 100644 index b950aa3c5f2..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common.client - -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable -import kotlinx.serialization.json.JsonObject - -@Serializable -internal data class GenerationConfig( - val temperature: Float?, - @SerialName("top_p") val topP: Float?, - @SerialName("top_k") val topK: Int?, - @SerialName("candidate_count") val candidateCount: Int?, - @SerialName("max_output_tokens") val maxOutputTokens: Int?, - @SerialName("stop_sequences") val stopSequences: List?, - @SerialName("response_mime_type") val responseMimeType: String? = null, - @SerialName("presence_penalty") val presencePenalty: Float? = null, - @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, - @SerialName("response_schema") val responseSchema: Schema? = null, -) - -@Serializable -internal data class Tool( - val functionDeclarations: List? = null, - // This is a json object because it is not possible to make a data class with no parameters. - val codeExecution: JsonObject? = null, -) - -@Serializable -internal data class ToolConfig( - @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig? -) - -@Serializable -internal data class FunctionCallingConfig( - val mode: Mode, - @SerialName("allowed_function_names") val allowedFunctionNames: List? = null -) { - @Serializable - enum class Mode { - @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, - AUTO, - ANY, - NONE - } -} - -@Serializable -internal data class FunctionDeclaration( - val name: String, - val description: String, - val parameters: Schema -) - -@Serializable -internal data class Schema( - val type: String, - val description: String? = null, - val format: String? = null, - val nullable: Boolean? = false, - val enum: List? = null, - val properties: Map? = null, - val required: List? = null, - val items: Schema? = null, -) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/server/Types.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/server/Types.kt deleted file mode 100644 index 3749d534e47..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/server/Types.kt +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common.server - -import com.google.firebase.vertexai.common.shared.Content -import com.google.firebase.vertexai.common.shared.HarmCategory -import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.KSerializer -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable -import kotlinx.serialization.json.JsonNames - -internal object BlockReasonSerializer : - KSerializer by FirstOrdinalSerializer(BlockReason::class) - -internal object HarmProbabilitySerializer : - KSerializer by FirstOrdinalSerializer(HarmProbability::class) - -internal object HarmSeveritySerializer : - KSerializer by FirstOrdinalSerializer(HarmSeverity::class) - -internal object FinishReasonSerializer : - KSerializer by FirstOrdinalSerializer(FinishReason::class) - -@Serializable -internal data class PromptFeedback( - val blockReason: BlockReason? = null, - val safetyRatings: List? = null, - val blockReasonMessage: String? = null, -) - -@Serializable(BlockReasonSerializer::class) -internal enum class BlockReason { - UNKNOWN, - @SerialName("BLOCKED_REASON_UNSPECIFIED") UNSPECIFIED, - SAFETY, - OTHER -} - -@Serializable -internal data class Candidate( - val content: Content? = null, - val finishReason: FinishReason? = null, - val safetyRatings: List? = null, - val citationMetadata: CitationMetadata? = null, - val groundingMetadata: GroundingMetadata? = null, -) - -@Serializable -internal data class CitationMetadata -@OptIn(ExperimentalSerializationApi::class) -internal constructor(@JsonNames("citations") val citationSources: List) - -@Serializable -internal data class CitationSources( - val title: String? = null, - val startIndex: Int = 0, - val endIndex: Int, - val uri: String? = null, - val license: String? = null, - val publicationDate: Date? = null, -) - -@Serializable -internal data class Date( - /** Year of the date. Must be between 1 and 9999, or 0 for no year. */ - val year: Int? = null, - /** 1-based index for month. Must be from 1 to 12, or 0 to specify a year without a month. */ - val month: Int? = null, - /** - * Day of a month. Must be from 1 to 31 and valid for the year and month, or 0 to specify a year - * by itself or a year and month where the day isn't significant. - */ - val day: Int? = null, -) - -@Serializable -internal data class SafetyRating( - val category: HarmCategory, - val probability: HarmProbability, - val blocked: Boolean? = null, // TODO(): any reason not to default to false? - val probabilityScore: Float? = null, - val severity: HarmSeverity? = null, - val severityScore: Float? = null, -) - -@Serializable -internal data class GroundingMetadata( - @SerialName("web_search_queries") val webSearchQueries: List?, - @SerialName("search_entry_point") val searchEntryPoint: SearchEntryPoint?, - @SerialName("retrieval_queries") val retrievalQueries: List?, - @SerialName("grounding_attribution") val groundingAttribution: List?, -) - -@Serializable -internal data class SearchEntryPoint( - @SerialName("rendered_content") val renderedContent: String?, - @SerialName("sdk_blob") val sdkBlob: String?, -) - -@Serializable -internal data class GroundingAttribution( - val segment: Segment, - @SerialName("confidence_score") val confidenceScore: Float?, -) - -@Serializable -internal data class Segment( - @SerialName("start_index") val startIndex: Int, - @SerialName("end_index") val endIndex: Int, -) - -@Serializable(HarmProbabilitySerializer::class) -internal enum class HarmProbability { - UNKNOWN, - @SerialName("HARM_PROBABILITY_UNSPECIFIED") UNSPECIFIED, - NEGLIGIBLE, - LOW, - MEDIUM, - HIGH -} - -@Serializable(HarmSeveritySerializer::class) -internal enum class HarmSeverity { - UNKNOWN, - @SerialName("HARM_SEVERITY_UNSPECIFIED") UNSPECIFIED, - @SerialName("HARM_SEVERITY_NEGLIGIBLE") NEGLIGIBLE, - @SerialName("HARM_SEVERITY_LOW") LOW, - @SerialName("HARM_SEVERITY_MEDIUM") MEDIUM, - @SerialName("HARM_SEVERITY_HIGH") HIGH -} - -@Serializable(FinishReasonSerializer::class) -internal enum class FinishReason { - UNKNOWN, - @SerialName("FINISH_REASON_UNSPECIFIED") UNSPECIFIED, - STOP, - MAX_TOKENS, - SAFETY, - RECITATION, - OTHER -} - -@Serializable -internal data class GRpcError( - val code: Int, - val message: String, - val details: List? = null -) - -@Serializable -internal data class GRpcErrorDetails( - val reason: String? = null, - val domain: String? = null, - val metadata: Map? = null -) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt deleted file mode 100644 index b32772e995c..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common.shared - -import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer -import kotlinx.serialization.DeserializationStrategy -import kotlinx.serialization.EncodeDefault -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.KSerializer -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable -import kotlinx.serialization.SerializationException -import kotlinx.serialization.json.JsonContentPolymorphicSerializer -import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.jsonObject - -internal object HarmCategorySerializer : - KSerializer by FirstOrdinalSerializer(HarmCategory::class) - -@Serializable(HarmCategorySerializer::class) -internal enum class HarmCategory { - UNKNOWN, - @SerialName("HARM_CATEGORY_HARASSMENT") HARASSMENT, - @SerialName("HARM_CATEGORY_HATE_SPEECH") HATE_SPEECH, - @SerialName("HARM_CATEGORY_SEXUALLY_EXPLICIT") SEXUALLY_EXPLICIT, - @SerialName("HARM_CATEGORY_DANGEROUS_CONTENT") DANGEROUS_CONTENT, - @SerialName("HARM_CATEGORY_CIVIC_INTEGRITY") CIVIC_INTEGRITY, -} - -internal typealias Base64 = String - -@ExperimentalSerializationApi -@Serializable -internal data class Content(@EncodeDefault val role: String? = "user", val parts: List) - -@Serializable(PartSerializer::class) internal sealed interface Part - -@Serializable internal data class TextPart(val text: String) : Part - -@Serializable -internal data class InlineDataPart(@SerialName("inline_data") val inlineData: InlineData) : Part - -@Serializable internal data class FunctionCallPart(val functionCall: FunctionCall) : Part - -@Serializable -internal data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part - -@Serializable internal data class FunctionResponse(val name: String, val response: JsonObject) - -@Serializable -internal data class FunctionCall(val name: String, val args: Map? = null) - -@Serializable -internal data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part - -@Serializable -internal data class FileData( - @SerialName("mime_type") val mimeType: String, - @SerialName("file_uri") val fileUri: String, -) - -@Serializable -internal data class InlineData(@SerialName("mime_type") val mimeType: String, val data: Base64) - -@Serializable -internal data class SafetySetting( - val category: HarmCategory, - val threshold: HarmBlockThreshold, - val method: HarmBlockMethod? = null, -) - -@Serializable -internal enum class HarmBlockThreshold { - @SerialName("HARM_BLOCK_THRESHOLD_UNSPECIFIED") UNSPECIFIED, - BLOCK_LOW_AND_ABOVE, - BLOCK_MEDIUM_AND_ABOVE, - BLOCK_ONLY_HIGH, - BLOCK_NONE, -} - -@Serializable -internal enum class HarmBlockMethod { - @SerialName("HARM_BLOCK_METHOD_UNSPECIFIED") UNSPECIFIED, - SEVERITY, - PROBABILITY, -} - -internal object PartSerializer : JsonContentPolymorphicSerializer(Part::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - val jsonObject = element.jsonObject - return when { - "text" in jsonObject -> TextPart.serializer() - "functionCall" in jsonObject -> FunctionCallPart.serializer() - "functionResponse" in jsonObject -> FunctionResponsePart.serializer() - "inlineData" in jsonObject -> InlineDataPart.serializer() - "fileData" in jsonObject -> FileDataPart.serializer() - else -> throw SerializationException("Unknown Part type") - } - } -} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt deleted file mode 100644 index ed0efc394c0..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt +++ /dev/null @@ -1,377 +0,0 @@ -/* - * Copyright 2023 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.internal.util - -import android.graphics.Bitmap -import android.graphics.BitmapFactory -import android.util.Base64 -import com.google.firebase.vertexai.common.client.Schema -import com.google.firebase.vertexai.common.shared.FileData -import com.google.firebase.vertexai.common.shared.FunctionCall -import com.google.firebase.vertexai.common.shared.FunctionCallPart -import com.google.firebase.vertexai.common.shared.FunctionResponse -import com.google.firebase.vertexai.common.shared.FunctionResponsePart -import com.google.firebase.vertexai.common.shared.InlineData -import com.google.firebase.vertexai.type.BlockReason -import com.google.firebase.vertexai.type.Candidate -import com.google.firebase.vertexai.type.Citation -import com.google.firebase.vertexai.type.CitationMetadata -import com.google.firebase.vertexai.type.Content -import com.google.firebase.vertexai.type.CountTokensResponse -import com.google.firebase.vertexai.type.FileDataPart -import com.google.firebase.vertexai.type.FinishReason -import com.google.firebase.vertexai.type.FunctionCallingConfig -import com.google.firebase.vertexai.type.FunctionDeclaration -import com.google.firebase.vertexai.type.GenerateContentResponse -import com.google.firebase.vertexai.type.GenerationConfig -import com.google.firebase.vertexai.type.HarmBlockMethod -import com.google.firebase.vertexai.type.HarmBlockThreshold -import com.google.firebase.vertexai.type.HarmCategory -import com.google.firebase.vertexai.type.HarmProbability -import com.google.firebase.vertexai.type.HarmSeverity -import com.google.firebase.vertexai.type.ImagePart -import com.google.firebase.vertexai.type.InlineDataPart -import com.google.firebase.vertexai.type.Part -import com.google.firebase.vertexai.type.PromptFeedback -import com.google.firebase.vertexai.type.SafetyRating -import com.google.firebase.vertexai.type.SafetySetting -import com.google.firebase.vertexai.type.SerializationException -import com.google.firebase.vertexai.type.TextPart -import com.google.firebase.vertexai.type.Tool -import com.google.firebase.vertexai.type.ToolConfig -import com.google.firebase.vertexai.type.UsageMetadata -import com.google.firebase.vertexai.type.content -import java.io.ByteArrayOutputStream -import java.util.Calendar -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonNull -import kotlinx.serialization.json.JsonObject -import org.json.JSONObject - -private const val BASE_64_FLAGS = Base64.NO_WRAP - -internal fun Content.toInternal() = - com.google.firebase.vertexai.common.shared.Content( - this.role ?: "user", - this.parts.map { it.toInternal() } - ) - -internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part { - return when (this) { - is TextPart -> com.google.firebase.vertexai.common.shared.TextPart(text) - is ImagePart -> - com.google.firebase.vertexai.common.shared.InlineDataPart( - InlineData("image/jpeg", encodeBitmapToBase64Png(image)) - ) - is InlineDataPart -> - com.google.firebase.vertexai.common.shared.InlineDataPart( - InlineData(mimeType, Base64.encodeToString(inlineData, BASE_64_FLAGS)) - ) - is com.google.firebase.vertexai.type.FunctionCallPart -> - FunctionCallPart(FunctionCall(name, args)) - is com.google.firebase.vertexai.type.FunctionResponsePart -> - FunctionResponsePart(FunctionResponse(name, response)) - is FileDataPart -> - com.google.firebase.vertexai.common.shared.FileDataPart( - FileData(mimeType = mimeType, fileUri = uri) - ) - else -> - throw SerializationException( - "The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet." - ) - } -} - -internal fun SafetySetting.toInternal() = - com.google.firebase.vertexai.common.shared.SafetySetting( - harmCategory.toInternal(), - threshold.toInternal(), - method?.toInternal() - ) - -internal fun makeMissingCaseException(source: String, ordinal: Int): SerializationException { - return SerializationException( - """ - |Missing case for a $source: $ordinal - |This error indicates that one of the `toInternal` conversions needs updating. - |If you're a developer seeing this exception, please file an issue on our GitHub repo: - |https://github.com/firebase/firebase-android-sdk - """ - .trimMargin() - ) -} - -internal fun GenerationConfig.toInternal() = - com.google.firebase.vertexai.common.client.GenerationConfig( - temperature = temperature, - topP = topP, - topK = topK, - candidateCount = candidateCount, - maxOutputTokens = maxOutputTokens, - stopSequences = stopSequences, - frequencyPenalty = frequencyPenalty, - presencePenalty = presencePenalty, - responseMimeType = responseMimeType, - responseSchema = responseSchema?.toInternal() - ) - -internal fun HarmCategory.toInternal() = - when (this) { - HarmCategory.HARASSMENT -> com.google.firebase.vertexai.common.shared.HarmCategory.HARASSMENT - HarmCategory.HATE_SPEECH -> com.google.firebase.vertexai.common.shared.HarmCategory.HATE_SPEECH - HarmCategory.SEXUALLY_EXPLICIT -> - com.google.firebase.vertexai.common.shared.HarmCategory.SEXUALLY_EXPLICIT - HarmCategory.DANGEROUS_CONTENT -> - com.google.firebase.vertexai.common.shared.HarmCategory.DANGEROUS_CONTENT - HarmCategory.CIVIC_INTEGRITY -> - com.google.firebase.vertexai.common.shared.HarmCategory.CIVIC_INTEGRITY - HarmCategory.UNKNOWN -> com.google.firebase.vertexai.common.shared.HarmCategory.UNKNOWN - else -> throw makeMissingCaseException("HarmCategory", ordinal) - } - -internal fun HarmBlockMethod.toInternal() = - when (this) { - HarmBlockMethod.SEVERITY -> com.google.firebase.vertexai.common.shared.HarmBlockMethod.SEVERITY - HarmBlockMethod.PROBABILITY -> - com.google.firebase.vertexai.common.shared.HarmBlockMethod.PROBABILITY - else -> throw makeMissingCaseException("HarmBlockMethod", ordinal) - } - -internal fun ToolConfig.toInternal() = - com.google.firebase.vertexai.common.client.ToolConfig( - functionCallingConfig?.let { - com.google.firebase.vertexai.common.client.FunctionCallingConfig( - when (it.mode) { - FunctionCallingConfig.Mode.ANY -> - com.google.firebase.vertexai.common.client.FunctionCallingConfig.Mode.ANY - FunctionCallingConfig.Mode.AUTO -> - com.google.firebase.vertexai.common.client.FunctionCallingConfig.Mode.AUTO - FunctionCallingConfig.Mode.NONE -> - com.google.firebase.vertexai.common.client.FunctionCallingConfig.Mode.NONE - }, - it.allowedFunctionNames - ) - } - ) - -internal fun HarmBlockThreshold.toInternal() = - when (this) { - HarmBlockThreshold.NONE -> - com.google.firebase.vertexai.common.shared.HarmBlockThreshold.BLOCK_NONE - HarmBlockThreshold.ONLY_HIGH -> - com.google.firebase.vertexai.common.shared.HarmBlockThreshold.BLOCK_ONLY_HIGH - HarmBlockThreshold.MEDIUM_AND_ABOVE -> - com.google.firebase.vertexai.common.shared.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE - HarmBlockThreshold.LOW_AND_ABOVE -> - com.google.firebase.vertexai.common.shared.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE - else -> throw makeMissingCaseException("HarmBlockThreshold", ordinal) - } - -internal fun Tool.toInternal() = - com.google.firebase.vertexai.common.client.Tool( - functionDeclarations?.map { it.toInternal() } ?: emptyList() - ) - -internal fun FunctionDeclaration.toInternal() = - com.google.firebase.vertexai.common.client.FunctionDeclaration(name, "", schema.toInternal()) - -internal fun com.google.firebase.vertexai.type.Schema.toInternal(): Schema = - Schema( - type, - description, - format, - nullable, - enum, - properties?.mapValues { it.value.toInternal() }, - required, - items?.toInternal(), - ) - -internal fun JSONObject.toInternal() = Json.decodeFromString(toString()) - -internal fun com.google.firebase.vertexai.common.server.Candidate.toPublic(): Candidate { - val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() - val citations = citationMetadata?.toPublic() - val finishReason = finishReason.toPublic() - - return Candidate( - this.content?.toPublic() ?: content("model") {}, - safetyRatings, - citations, - finishReason - ) -} - -internal fun com.google.firebase.vertexai.common.UsageMetadata.toPublic(): UsageMetadata = - UsageMetadata(promptTokenCount ?: 0, candidatesTokenCount ?: 0, totalTokenCount ?: 0) - -internal fun com.google.firebase.vertexai.common.shared.Content.toPublic(): Content { - val returnedParts = parts.map { it.toPublic() }.filterNot { it is TextPart && it.text.isEmpty() } - // If all returned parts were text and empty, we coalesce them into a single one-character string - // part so the backend doesn't fail if we send this back as part of a multi-turn interaction. - return Content(role, returnedParts.ifEmpty { listOf(TextPart(" ")) }) -} - -internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part { - return when (this) { - is com.google.firebase.vertexai.common.shared.TextPart -> TextPart(text) - is com.google.firebase.vertexai.common.shared.InlineDataPart -> { - val data = Base64.decode(inlineData.data, BASE_64_FLAGS) - if (inlineData.mimeType.contains("image")) { - ImagePart(decodeBitmapFromImage(data)) - } else { - InlineDataPart(data, inlineData.mimeType) - } - } - is FunctionCallPart -> - com.google.firebase.vertexai.type.FunctionCallPart( - functionCall.name, - functionCall.args.orEmpty().mapValues { it.value ?: JsonNull } - ) - is FunctionResponsePart -> - com.google.firebase.vertexai.type.FunctionResponsePart( - functionResponse.name, - functionResponse.response, - ) - is com.google.firebase.vertexai.common.shared.FileDataPart -> - FileDataPart(fileData.mimeType, fileData.fileUri) - else -> - throw SerializationException( - "Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK." - ) - } -} - -internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic(): Citation { - val publicationDateAsCalendar = - publicationDate?.let { - val calendar = Calendar.getInstance() - // Internal `Date.year` uses 0 to represent not specified. We use 1 as default. - val year = if (it.year == null || it.year < 1) 1 else it.year - // Internal `Date.month` uses 0 to represent not specified, or is 1-12 as months. The month as - // expected by [Calendar] is 0-based, so we subtract 1 or use 0 as default. - val month = if (it.month == null || it.month < 1) 0 else it.month - 1 - // Internal `Date.day` uses 0 to represent not specified. We use 1 as default. - val day = if (it.day == null || it.day < 1) 1 else it.day - calendar.set(year, month, day) - calendar - } - return Citation( - title = title, - startIndex = startIndex, - endIndex = endIndex, - uri = uri, - license = license, - publicationDate = publicationDateAsCalendar - ) -} - -internal fun com.google.firebase.vertexai.common.server.CitationMetadata.toPublic() = - CitationMetadata(citationSources.map { it.toPublic() }) - -internal fun com.google.firebase.vertexai.common.server.SafetyRating.toPublic() = - SafetyRating( - category = category.toPublic(), - probability = probability.toPublic(), - probabilityScore = probabilityScore ?: 0f, - blocked = blocked, - severity = severity?.toPublic(), - severityScore = severityScore - ) - -internal fun com.google.firebase.vertexai.common.server.PromptFeedback.toPublic(): PromptFeedback { - val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() - return com.google.firebase.vertexai.type.PromptFeedback( - blockReason?.toPublic(), - safetyRatings, - blockReasonMessage - ) -} - -internal fun com.google.firebase.vertexai.common.server.FinishReason?.toPublic() = - when (this) { - null -> null - com.google.firebase.vertexai.common.server.FinishReason.MAX_TOKENS -> FinishReason.MAX_TOKENS - com.google.firebase.vertexai.common.server.FinishReason.RECITATION -> FinishReason.RECITATION - com.google.firebase.vertexai.common.server.FinishReason.SAFETY -> FinishReason.SAFETY - com.google.firebase.vertexai.common.server.FinishReason.STOP -> FinishReason.STOP - com.google.firebase.vertexai.common.server.FinishReason.OTHER -> FinishReason.OTHER - else -> FinishReason.UNKNOWN - } - -internal fun com.google.firebase.vertexai.common.shared.HarmCategory.toPublic() = - when (this) { - com.google.firebase.vertexai.common.shared.HarmCategory.HARASSMENT -> HarmCategory.HARASSMENT - com.google.firebase.vertexai.common.shared.HarmCategory.HATE_SPEECH -> HarmCategory.HATE_SPEECH - com.google.firebase.vertexai.common.shared.HarmCategory.SEXUALLY_EXPLICIT -> - HarmCategory.SEXUALLY_EXPLICIT - com.google.firebase.vertexai.common.shared.HarmCategory.DANGEROUS_CONTENT -> - HarmCategory.DANGEROUS_CONTENT - com.google.firebase.vertexai.common.shared.HarmCategory.CIVIC_INTEGRITY -> - HarmCategory.CIVIC_INTEGRITY - else -> HarmCategory.UNKNOWN - } - -internal fun com.google.firebase.vertexai.common.server.HarmProbability.toPublic() = - when (this) { - com.google.firebase.vertexai.common.server.HarmProbability.HIGH -> HarmProbability.HIGH - com.google.firebase.vertexai.common.server.HarmProbability.MEDIUM -> HarmProbability.MEDIUM - com.google.firebase.vertexai.common.server.HarmProbability.LOW -> HarmProbability.LOW - com.google.firebase.vertexai.common.server.HarmProbability.NEGLIGIBLE -> - HarmProbability.NEGLIGIBLE - else -> HarmProbability.UNKNOWN - } - -internal fun com.google.firebase.vertexai.common.server.HarmSeverity.toPublic() = - when (this) { - com.google.firebase.vertexai.common.server.HarmSeverity.HIGH -> HarmSeverity.HIGH - com.google.firebase.vertexai.common.server.HarmSeverity.MEDIUM -> HarmSeverity.MEDIUM - com.google.firebase.vertexai.common.server.HarmSeverity.LOW -> HarmSeverity.LOW - com.google.firebase.vertexai.common.server.HarmSeverity.NEGLIGIBLE -> HarmSeverity.NEGLIGIBLE - else -> HarmSeverity.UNKNOWN - } - -internal fun com.google.firebase.vertexai.common.server.BlockReason.toPublic() = - when (this) { - com.google.firebase.vertexai.common.server.BlockReason.SAFETY -> BlockReason.SAFETY - com.google.firebase.vertexai.common.server.BlockReason.OTHER -> BlockReason.OTHER - else -> BlockReason.UNKNOWN - } - -internal fun com.google.firebase.vertexai.common.GenerateContentResponse.toPublic(): - GenerateContentResponse { - return GenerateContentResponse( - candidates?.map { it.toPublic() }.orEmpty(), - promptFeedback?.toPublic(), - usageMetadata?.toPublic() - ) -} - -internal fun com.google.firebase.vertexai.common.CountTokensResponse.toPublic() = - CountTokensResponse(totalTokens, totalBillableCharacters ?: 0) - -internal fun JsonObject.toPublic() = JSONObject(toString()) - -private fun encodeBitmapToBase64Png(input: Bitmap): String { - ByteArrayOutputStream().let { - input.compress(Bitmap.CompressFormat.JPEG, 80, it) - return Base64.encodeToString(it.toByteArray(), BASE_64_FLAGS) - } -} - -private fun decodeBitmapFromImage(input: ByteArray) = - BitmapFactory.decodeByteArray(input, 0, input.size) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt index 6d8d96eb047..54cca4a80b4 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt @@ -16,7 +16,13 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer import java.util.Calendar +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonNames /** * A `Candidate` represents a single response generated by the model for a given request. @@ -32,7 +38,58 @@ internal constructor( public val safetyRatings: List, public val citationMetadata: CitationMetadata?, public val finishReason: FinishReason? -) +) { + + @Serializable + internal data class Internal( + val content: Content.Internal? = null, + val finishReason: FinishReason.Internal? = null, + val safetyRatings: List? = null, + val citationMetadata: CitationMetadata.Internal? = null, + val groundingMetadata: GroundingMetadata? = null, + ) { + internal fun toPublic(): Candidate { + val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() + val citations = citationMetadata?.toPublic() + val finishReason = finishReason?.toPublic() + + return Candidate( + this.content?.toPublic() ?: content("model") {}, + safetyRatings, + citations, + finishReason + ) + } + + @Serializable + internal data class GroundingMetadata( + @SerialName("web_search_queries") val webSearchQueries: List?, + @SerialName("search_entry_point") val searchEntryPoint: SearchEntryPoint?, + @SerialName("retrieval_queries") val retrievalQueries: List?, + @SerialName("grounding_attribution") val groundingAttribution: List?, + ) { + + @Serializable + internal data class SearchEntryPoint( + @SerialName("rendered_content") val renderedContent: String?, + @SerialName("sdk_blob") val sdkBlob: String?, + ) + + @Serializable + internal data class GroundingAttribution( + val segment: Segment, + @SerialName("confidence_score") val confidenceScore: Float?, + ) { + + @Serializable + internal data class Segment( + @SerialName("start_index") val startIndex: Int, + @SerialName("end_index") val endIndex: Int, + ) + } + } + } +} /** * An assessment of the potential harm of some generated content. @@ -55,7 +112,31 @@ internal constructor( public val blocked: Boolean? = null, public val severity: HarmSeverity? = null, public val severityScore: Float? = null -) +) { + + @Serializable + internal data class Internal + @JvmOverloads + constructor( + val category: HarmCategory.Internal, + val probability: HarmProbability.Internal, + val blocked: Boolean? = null, // TODO(): any reason not to default to false? + val probabilityScore: Float? = null, + val severity: HarmSeverity.Internal? = null, + val severityScore: Float? = null, + ) { + + internal fun toPublic() = + SafetyRating( + category = category.toPublic(), + probability = probability.toPublic(), + probabilityScore = probabilityScore ?: 0f, + blocked = blocked, + severity = severity?.toPublic(), + severityScore = severityScore + ) + } +} /** * A collection of source attributions for a piece of content. @@ -63,7 +144,16 @@ internal constructor( * @property citations A list of individual cited sources and the parts of the content to which they * apply. */ -public class CitationMetadata internal constructor(public val citations: List) +public class CitationMetadata internal constructor(public val citations: List) { + + @Serializable + internal data class Internal + @OptIn(ExperimentalSerializationApi::class) + internal constructor(@JsonNames("citations") val citationSources: List) { + + internal fun toPublic() = CitationMetadata(citationSources.map { it.toPublic() }) + } +} /** * Represents a citation of content from an external source within the model's output. @@ -89,7 +179,57 @@ internal constructor( public val uri: String? = null, public val license: String? = null, public val publicationDate: Calendar? = null -) +) { + + @Serializable + internal data class Internal( + val title: String? = null, + val startIndex: Int = 0, + val endIndex: Int, + val uri: String? = null, + val license: String? = null, + val publicationDate: Date? = null, + ) { + + internal fun toPublic(): Citation { + val publicationDateAsCalendar = + publicationDate?.let { + val calendar = Calendar.getInstance() + // Internal `Date.year` uses 0 to represent not specified. We use 1 as default. + val year = if (it.year == null || it.year < 1) 1 else it.year + // Internal `Date.month` uses 0 to represent not specified, or is 1-12 as months. The + // month as + // expected by [Calendar] is 0-based, so we subtract 1 or use 0 as default. + val month = if (it.month == null || it.month < 1) 0 else it.month - 1 + // Internal `Date.day` uses 0 to represent not specified. We use 1 as default. + val day = if (it.day == null || it.day < 1) 1 else it.day + calendar.set(year, month, day) + calendar + } + return Citation( + title = title, + startIndex = startIndex, + endIndex = endIndex, + uri = uri, + license = license, + publicationDate = publicationDateAsCalendar + ) + } + + @Serializable + internal data class Date( + /** Year of the date. Must be between 1 and 9999, or 0 for no year. */ + val year: Int? = null, + /** 1-based index for month. Must be from 1 to 12, or 0 to specify a year without a month. */ + val month: Int? = null, + /** + * Day of a month. Must be from 1 to 31 and valid for the year and month, or 0 to specify a + * year by itself or a year and month where the day isn't significant. + */ + val day: Int? = null, + ) + } +} /** * Represents the reason why the model stopped generating content. @@ -98,6 +238,29 @@ internal constructor( * @property ordinal The ordinal value of the finish reason. */ public class FinishReason private constructor(public val name: String, public val ordinal: Int) { + + @Serializable(Internal.Serializer::class) + internal enum class Internal { + UNKNOWN, + @SerialName("FINISH_REASON_UNSPECIFIED") UNSPECIFIED, + STOP, + MAX_TOKENS, + SAFETY, + RECITATION, + OTHER; + + internal object Serializer : KSerializer by FirstOrdinalSerializer(Internal::class) + + internal fun toPublic() = + when (this) { + MAX_TOKENS -> FinishReason.MAX_TOKENS + RECITATION -> FinishReason.RECITATION + SAFETY -> FinishReason.SAFETY + STOP -> FinishReason.STOP + OTHER -> FinishReason.OTHER + else -> FinishReason.UNKNOWN + } + } public companion object { /** A new and not yet supported value. */ @JvmField public val UNKNOWN: FinishReason = FinishReason("UNKNOWN", 0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt index ec3e9555741..241d0becfe6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt @@ -17,6 +17,9 @@ package com.google.firebase.vertexai.type import android.graphics.Bitmap +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.Serializable /** * Represents content sent to and received from the model. @@ -76,6 +79,24 @@ constructor(public val role: String? = "user", public val parts: List) { /** Returns a new [Content] using the defined [role] and [parts]. */ public fun build(): Content = Content(role, parts) } + + internal fun toInternal() = Internal(this.role ?: "user", this.parts.map { it.toInternal() }) + + @ExperimentalSerializationApi + @Serializable + internal data class Internal( + @EncodeDefault val role: String? = "user", + val parts: List + ) { + internal fun toPublic(): Content { + val returnedParts = + parts.map { it.toPublic() }.filterNot { it is TextPart && it.text.isEmpty() } + // If all returned parts were text and empty, we coalesce them into a single one-character + // string + // part so the backend doesn't fail if we send this back as part of a multi-turn interaction. + return Content(role, returnedParts.ifEmpty { listOf(TextPart(" ")) }) + } + } } /** diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt index 2835deba6f7..4c05521ad65 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + /** * The model's response to a count tokens request. * @@ -36,4 +38,13 @@ public class CountTokensResponse( public operator fun component1(): Int = totalTokens public operator fun component2(): Int? = totalBillableCharacters + + @Serializable + internal data class Internal(val totalTokens: Int, val totalBillableCharacters: Int? = null) : + Response { + + internal fun toPublic(): CountTokensResponse { + return CountTokensResponse(totalTokens, totalBillableCharacters ?: 0) + } + } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt index a3bd95e15ab..4f4ca954f36 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt @@ -18,7 +18,6 @@ package com.google.firebase.vertexai.type import com.google.firebase.vertexai.FirebaseVertexAI import com.google.firebase.vertexai.common.FirebaseCommonAIException -import com.google.firebase.vertexai.internal.util.toPublic import kotlinx.coroutines.TimeoutCancellationException /** Parent class for any errors that occur from the [FirebaseVertexAI] SDK. */ diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt index a2ea9b1d01e..ee557556bbc 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt @@ -16,6 +16,9 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** * The configuration that specifies the function calling behavior. * @@ -42,7 +45,21 @@ internal constructor( * The model will never predict a function call to answer a query. This can also be achieved by * not passing any tools to the model. */ - NONE + NONE, + } + + @Serializable + internal data class Internal( + val mode: Mode, + @SerialName("allowed_function_names") val allowedFunctionNames: List? = null + ) { + @Serializable + enum class Mode { + @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, + AUTO, + ANY, + NONE, + } } public companion object { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt index 672293bb559..8813de18b43 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + /** * Defines a function that the model can use as a tool. * @@ -58,4 +60,13 @@ public class FunctionDeclaration( ) { internal val schema: Schema = Schema.obj(properties = parameters, optionalProperties = optionalParameters, nullable = false) + + internal fun toInternal() = Internal(name, "", schema.toInternal()) + + @Serializable + internal data class Internal( + val name: String, + val description: String, + val parameters: Schema.Internal + ) } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt index 85891457b78..00395252914 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + /** * A response from the model. * @@ -41,4 +43,19 @@ public class GenerateContentResponse( public val functionCalls: List by lazy { candidates.first().content.parts.filterIsInstance() } + + @Serializable + internal data class Internal( + val candidates: List? = null, + val promptFeedback: PromptFeedback.Internal? = null, + val usageMetadata: UsageMetadata.Internal? = null, + ) : Response { + internal fun toPublic(): GenerateContentResponse { + return GenerateContentResponse( + candidates?.map { it.toPublic() }.orEmpty(), + promptFeedback?.toPublic(), + usageMetadata?.toPublic() + ) + } + } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt index 8bf8d7a1ac7..4abec8a260d 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt @@ -16,6 +16,9 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** * Configuration parameters to use for content generation. * @@ -142,6 +145,34 @@ private constructor( ) } + internal fun toInternal() = + Internal( + temperature = temperature, + topP = topP, + topK = topK, + candidateCount = candidateCount, + maxOutputTokens = maxOutputTokens, + stopSequences = stopSequences, + frequencyPenalty = frequencyPenalty, + presencePenalty = presencePenalty, + responseMimeType = responseMimeType, + responseSchema = responseSchema?.toInternal() + ) + + @Serializable + internal data class Internal( + val temperature: Float?, + @SerialName("top_p") val topP: Float?, + @SerialName("top_k") val topK: Int?, + @SerialName("candidate_count") val candidateCount: Int?, + @SerialName("max_output_tokens") val maxOutputTokens: Int?, + @SerialName("stop_sequences") val stopSequences: List?, + @SerialName("response_mime_type") val responseMimeType: String? = null, + @SerialName("presence_penalty") val presencePenalty: Float? = null, + @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, + @SerialName("response_schema") val responseSchema: Schema.Internal? = null, + ) + public companion object { /** diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt index e743964c64a..1bd16949b20 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt @@ -16,11 +16,28 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.makeMissingCaseException +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** * Specifies how the block method computes the score that will be compared against the * [HarmBlockThreshold] in [SafetySetting]. */ public class HarmBlockMethod private constructor(public val ordinal: Int) { + internal fun toInternal() = + when (this) { + SEVERITY -> Internal.SEVERITY + PROBABILITY -> Internal.PROBABILITY + else -> throw makeMissingCaseException("HarmBlockMethod", ordinal) + } + + @Serializable + internal enum class Internal { + @SerialName("HARM_BLOCK_METHOD_UNSPECIFIED") UNSPECIFIED, + SEVERITY, + PROBABILITY, + } public companion object { /** * The harm block method uses both probability and severity scores. See [HarmSeverity] and diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt index 073416112ab..1b3233bda2a 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt @@ -16,8 +16,31 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.makeMissingCaseException +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** Represents the threshold for a [HarmCategory] to be allowed by [SafetySetting]. */ public class HarmBlockThreshold private constructor(public val ordinal: Int) { + + internal fun toInternal() = + when (this) { + NONE -> Internal.BLOCK_NONE + ONLY_HIGH -> Internal.BLOCK_ONLY_HIGH + MEDIUM_AND_ABOVE -> Internal.BLOCK_MEDIUM_AND_ABOVE + LOW_AND_ABOVE -> Internal.BLOCK_LOW_AND_ABOVE + else -> throw makeMissingCaseException("HarmBlockThreshold", ordinal) + } + + @Serializable + internal enum class Internal { + @SerialName("HARM_BLOCK_THRESHOLD_UNSPECIFIED") UNSPECIFIED, + BLOCK_LOW_AND_ABOVE, + BLOCK_MEDIUM_AND_ABOVE, + BLOCK_ONLY_HIGH, + BLOCK_NONE, + } + public companion object { /** Content with negligible harm is allowed. */ @JvmField public val LOW_AND_ABOVE: HarmBlockThreshold = HarmBlockThreshold(0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt index d19de2e1568..2429688b02b 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt @@ -16,8 +16,45 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.makeMissingCaseException +import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** Category for a given harm rating. */ public class HarmCategory private constructor(public val ordinal: Int) { + internal fun toInternal() = + when (this) { + HARASSMENT -> Internal.HARASSMENT + HATE_SPEECH -> Internal.HATE_SPEECH + SEXUALLY_EXPLICIT -> Internal.SEXUALLY_EXPLICIT + DANGEROUS_CONTENT -> Internal.DANGEROUS_CONTENT + CIVIC_INTEGRITY -> Internal.CIVIC_INTEGRITY + UNKNOWN -> Internal.UNKNOWN + else -> throw makeMissingCaseException("HarmCategory", ordinal) + } + @Serializable(Internal.Serializer::class) + internal enum class Internal { + UNKNOWN, + @SerialName("HARM_CATEGORY_HARASSMENT") HARASSMENT, + @SerialName("HARM_CATEGORY_HATE_SPEECH") HATE_SPEECH, + @SerialName("HARM_CATEGORY_SEXUALLY_EXPLICIT") SEXUALLY_EXPLICIT, + @SerialName("HARM_CATEGORY_DANGEROUS_CONTENT") DANGEROUS_CONTENT, + @SerialName("HARM_CATEGORY_CIVIC_INTEGRITY") CIVIC_INTEGRITY; + + internal object Serializer : KSerializer by FirstOrdinalSerializer(Internal::class) + + internal fun toPublic() = + when (this) { + HARASSMENT -> HarmCategory.HARASSMENT + HATE_SPEECH -> HarmCategory.HATE_SPEECH + SEXUALLY_EXPLICIT -> HarmCategory.SEXUALLY_EXPLICIT + DANGEROUS_CONTENT -> HarmCategory.DANGEROUS_CONTENT + CIVIC_INTEGRITY -> HarmCategory.CIVIC_INTEGRITY + else -> HarmCategory.UNKNOWN + } + } public companion object { /** A new and not yet supported value. */ @JvmField public val UNKNOWN: HarmCategory = HarmCategory(0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt index d4208f7bf85..3d13e177819 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt @@ -16,8 +16,33 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** Represents the probability that some [HarmCategory] is applicable in a [SafetyRating]. */ public class HarmProbability private constructor(public val ordinal: Int) { + @Serializable(Internal.Serializer::class) + internal enum class Internal { + UNKNOWN, + @SerialName("HARM_PROBABILITY_UNSPECIFIED") UNSPECIFIED, + NEGLIGIBLE, + LOW, + MEDIUM, + HIGH; + + internal object Serializer : KSerializer by FirstOrdinalSerializer(Internal::class) + + internal fun toPublic() = + when (this) { + HIGH -> HarmProbability.HIGH + MEDIUM -> HarmProbability.MEDIUM + LOW -> HarmProbability.LOW + NEGLIGIBLE -> HarmProbability.NEGLIGIBLE + else -> HarmProbability.UNKNOWN + } + } public companion object { /** A new and not yet supported value. */ @JvmField public val UNKNOWN: HarmProbability = HarmProbability(0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt index 40fe73ca906..0d0a39f2ac9 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt @@ -16,8 +16,33 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** Represents the severity of a [HarmCategory] being applicable in a [SafetyRating]. */ public class HarmSeverity private constructor(public val ordinal: Int) { + @Serializable(Internal.Serializer::class) + internal enum class Internal { + UNKNOWN, + @SerialName("HARM_SEVERITY_UNSPECIFIED") UNSPECIFIED, + @SerialName("HARM_SEVERITY_NEGLIGIBLE") NEGLIGIBLE, + @SerialName("HARM_SEVERITY_LOW") LOW, + @SerialName("HARM_SEVERITY_MEDIUM") MEDIUM, + @SerialName("HARM_SEVERITY_HIGH") HIGH; + + internal object Serializer : KSerializer by FirstOrdinalSerializer(Internal::class) + + internal fun toPublic() = + when (this) { + HIGH -> HarmSeverity.HIGH + MEDIUM -> HarmSeverity.MEDIUM + LOW -> HarmSeverity.LOW + NEGLIGIBLE -> HarmSeverity.NEGLIGIBLE + else -> HarmSeverity.UNKNOWN + } + } public companion object { /** A new and not yet supported value. */ @JvmField public val UNKNOWN: HarmSeverity = HarmSeverity(0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt index 41ddfcfbe41..cfffe6885b0 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt @@ -17,15 +17,27 @@ package com.google.firebase.vertexai.type import android.graphics.Bitmap +import android.graphics.BitmapFactory +import java.io.ByteArrayOutputStream +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.JsonContentPolymorphicSerializer import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.jsonObject import org.json.JSONObject /** Interface representing data sent to and received from requests. */ -public interface Part +public interface Part {} /** Represents text or string based data sent to and received from requests. */ -public class TextPart(public val text: String) : Part +public class TextPart(public val text: String) : Part { + + @Serializable internal data class Internal(val text: String) : InternalPart +} /** * Represents image data sent to and received from requests. When this is sent to the server it is @@ -42,7 +54,16 @@ public class ImagePart(public val image: Bitmap) : Part * @param mimeType an IANA standard MIME type. For supported values, see the * [Vertex AI documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#media_requirements) */ -public class InlineDataPart(public val inlineData: ByteArray, public val mimeType: String) : Part +public class InlineDataPart(public val inlineData: ByteArray, public val mimeType: String) : Part { + + @Serializable + internal data class Internal(@SerialName("inline_data") val inlineData: InlineData) : + InternalPart { + + @Serializable + internal data class InlineData(@SerialName("mime_type") val mimeType: String, val data: Base64) + } +} /** * Represents function call name and params received from requests. @@ -51,7 +72,15 @@ public class InlineDataPart(public val inlineData: ByteArray, public val mimeTyp * @param args the function parameters and values as a [Map] */ public class FunctionCallPart(public val name: String, public val args: Map) : - Part + Part { + + @Serializable + internal data class Internal(val functionCall: FunctionCall) : InternalPart { + + @Serializable + internal data class FunctionCall(val name: String, val args: Map? = null) + } +} /** * Represents function call output to be returned to the model when it requests a function call. @@ -59,7 +88,14 @@ public class FunctionCallPart(public val name: String, public val args: Map(InternalPart::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + return when { + "text" in jsonObject -> TextPart.Internal.serializer() + "functionCall" in jsonObject -> FunctionCallPart.Internal.serializer() + "functionResponse" in jsonObject -> FunctionResponsePart.Internal.serializer() + "inlineData" in jsonObject -> InlineDataPart.Internal.serializer() + "fileData" in jsonObject -> FileDataPart.Internal.serializer() + else -> throw SerializationException("Unknown Part type") + } + } +} + +internal fun Part.toInternal(): InternalPart { + return when (this) { + is TextPart -> TextPart.Internal(text) + is ImagePart -> + InlineDataPart.Internal( + InlineDataPart.Internal.InlineData("image/jpeg", encodeBitmapToBase64Png(image)) + ) + is InlineDataPart -> + InlineDataPart.Internal( + InlineDataPart.Internal.InlineData( + mimeType, + android.util.Base64.encodeToString(inlineData, BASE_64_FLAGS) + ) + ) + is FunctionCallPart -> + FunctionCallPart.Internal(FunctionCallPart.Internal.FunctionCall(name, args)) + is FunctionResponsePart -> + FunctionResponsePart.Internal(FunctionResponsePart.Internal.FunctionResponse(name, response)) + is FileDataPart -> + FileDataPart.Internal(FileDataPart.Internal.FileData(mimeType = mimeType, fileUri = uri)) + else -> + throw com.google.firebase.vertexai.type.SerializationException( + "The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet." + ) + } +} + +private fun encodeBitmapToBase64Png(input: Bitmap): String { + ByteArrayOutputStream().let { + input.compress(Bitmap.CompressFormat.JPEG, 80, it) + return android.util.Base64.encodeToString(it.toByteArray(), BASE_64_FLAGS) + } +} + +internal fun InternalPart.toPublic(): Part { + return when (this) { + is TextPart.Internal -> TextPart(text) + is InlineDataPart.Internal -> { + val data = android.util.Base64.decode(inlineData.data, BASE_64_FLAGS) + if (inlineData.mimeType.contains("image")) { + ImagePart(decodeBitmapFromImage(data)) + } else { + InlineDataPart(data, inlineData.mimeType) + } + } + is FunctionCallPart.Internal -> + FunctionCallPart( + functionCall.name, + functionCall.args.orEmpty().mapValues { it.value ?: JsonNull } + ) + is FunctionResponsePart.Internal -> + FunctionResponsePart( + functionResponse.name, + functionResponse.response, + ) + is FileDataPart.Internal -> FileDataPart(fileData.mimeType, fileData.fileUri) + else -> + throw com.google.firebase.vertexai.type.SerializationException( + "Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK." + ) + } +} + +private fun decodeBitmapFromImage(input: ByteArray) = + BitmapFactory.decodeByteArray(input, 0, input.size) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt index b4d06d04b8a..5f0e0edd017 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt @@ -16,6 +16,11 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** * Feedback on the prompt provided in the request. * @@ -27,10 +32,41 @@ public class PromptFeedback( public val blockReason: BlockReason?, public val safetyRatings: List, public val blockReasonMessage: String? -) +) { + + @Serializable + internal data class Internal( + val blockReason: BlockReason.Internal? = null, + val safetyRatings: List? = null, + val blockReasonMessage: String? = null, + ) { + + internal fun toPublic(): PromptFeedback { + val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() + return PromptFeedback(blockReason?.toPublic(), safetyRatings, blockReasonMessage) + } + } +} /** Describes why content was blocked. */ public class BlockReason private constructor(public val name: String, public val ordinal: Int) { + + @Serializable(Internal.Serializer::class) + internal enum class Internal { + UNKNOWN, + @SerialName("BLOCKED_REASON_UNSPECIFIED") UNSPECIFIED, + SAFETY, + OTHER; + + internal object Serializer : KSerializer by FirstOrdinalSerializer(Internal::class) + + internal fun toPublic() = + when (this) { + SAFETY -> BlockReason.SAFETY + OTHER -> BlockReason.OTHER + else -> BlockReason.UNKNOWN + } + } public companion object { /** A new and not yet supported value. */ @JvmField public val UNKNOWN: BlockReason = BlockReason("UNKNOWN", 0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt index 68d7f93aa99..8095c42c532 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + /** * A configuration for a [HarmBlockThreshold] of some [HarmCategory] allowed and blocked in * responses. @@ -29,4 +31,14 @@ public class SafetySetting( internal val harmCategory: HarmCategory, internal val threshold: HarmBlockThreshold, internal val method: HarmBlockMethod? = null, -) +) { + internal fun toInternal() = + Internal(harmCategory.toInternal(), threshold.toInternal(), method?.toInternal()) + + @Serializable + internal data class Internal( + val category: HarmCategory.Internal, + val threshold: HarmBlockThreshold.Internal, + val method: HarmBlockMethod.Internal? = null, + ) +} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt index b6f69e51d49..869d83b0eb9 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + public abstract class StringFormat private constructor(internal val value: String) { public class Custom(value: String) : StringFormat(value) } @@ -238,4 +240,27 @@ internal constructor( type = "STRING", ) } + + internal fun toInternal(): Internal = + Internal( + type, + description, + format, + nullable, + enum, + properties?.mapValues { it.value.toInternal() }, + required, + items?.toInternal(), + ) + @Serializable + internal data class Internal( + val type: String, + val description: String? = null, + val format: String? = null, + val nullable: Boolean? = false, + val enum: List? = null, + val properties: Map? = null, + val required: List? = null, + val items: Internal? = null, + ) } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt index 41cbf99f6c4..e62e02f55b1 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt @@ -16,6 +16,9 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + /** * Contains a set of function declarations that the model has access to. These can be used to gather * information, or complete tasks @@ -24,6 +27,13 @@ package com.google.firebase.vertexai.type */ public class Tool internal constructor(internal val functionDeclarations: List?) { + internal fun toInternal() = Internal(functionDeclarations?.map { it.toInternal() } ?: emptyList()) + @Serializable + internal data class Internal( + val functionDeclarations: List? = null, + // This is a json object because it is not possible to make a data class with no parameters. + val codeExecution: JsonObject? = null, + ) public companion object { /** diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt index ee26f6b0f57..99769ed46b6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt @@ -16,10 +16,34 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** * Contains configuration for the function calling tools of the model. This can be used to change * when the model can predict function calls. * * @param functionCallingConfig The config for function calling */ -public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfig?) +public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfig?) { + + internal fun toInternal() = + Internal( + functionCallingConfig?.let { + FunctionCallingConfig.Internal( + when (it.mode) { + FunctionCallingConfig.Mode.ANY -> FunctionCallingConfig.Internal.Mode.ANY + FunctionCallingConfig.Mode.AUTO -> FunctionCallingConfig.Internal.Mode.AUTO + FunctionCallingConfig.Mode.NONE -> FunctionCallingConfig.Internal.Mode.NONE + }, + it.allowedFunctionNames + ) + } + ) + + @Serializable + internal data class Internal( + @SerialName("function_calling_config") + val functionCallingConfig: FunctionCallingConfig.Internal? + ) +} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt index ff33240aa24..a35000c7da5 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt @@ -15,3 +15,33 @@ */ package com.google.firebase.vertexai.type + +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import org.json.JSONObject + +internal sealed interface Response + +@Serializable +internal data class GRpcErrorResponse(val error: GRpcError) : Response { + + @Serializable + internal data class GRpcError( + val code: Int, + val message: String, + val details: List? = null + ) { + + @Serializable + internal data class GRpcErrorDetails( + val reason: String? = null, + val domain: String? = null, + val metadata: Map? = null + ) + } +} + +internal fun JSONObject.toInternal() = Json.decodeFromString(toString()) + +internal fun JsonObject.toPublic() = JSONObject(toString()) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt index 21da0255cb9..54f5cbd89b7 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + /** * Usage metadata about response(s). * @@ -27,4 +29,16 @@ public class UsageMetadata( public val promptTokenCount: Int, public val candidatesTokenCount: Int?, public val totalTokenCount: Int -) +) { + + @Serializable + internal data class Internal( + val promptTokenCount: Int? = null, + val candidatesTokenCount: Int? = null, + val totalTokenCount: Int? = null, + ) { + + internal fun toPublic(): UsageMetadata = + UsageMetadata(promptTokenCount ?: 0, candidatesTokenCount ?: 0, totalTokenCount ?: 0) + } +} diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt index 8b668371a31..67d41c9b5d6 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt @@ -17,14 +17,14 @@ package com.google.firebase.vertexai import com.google.firebase.vertexai.common.APIController -import com.google.firebase.vertexai.common.GenerateContentResponse import com.google.firebase.vertexai.common.JSON -import com.google.firebase.vertexai.common.server.Candidate -import com.google.firebase.vertexai.common.shared.Content -import com.google.firebase.vertexai.common.shared.TextPart import com.google.firebase.vertexai.common.util.doBlocking +import com.google.firebase.vertexai.type.Candidate +import com.google.firebase.vertexai.type.Content +import com.google.firebase.vertexai.type.GenerateContentResponse import com.google.firebase.vertexai.type.RequestOptions import com.google.firebase.vertexai.type.ServerException +import com.google.firebase.vertexai.type.TextPart import com.google.firebase.vertexai.type.content import io.kotest.assertions.json.shouldContainJsonKey import io.kotest.assertions.json.shouldContainJsonKeyValue @@ -129,7 +129,9 @@ internal class GenerativeModelTesting { private fun generateContentResponseAsJsonString(text: String): String { return JSON.encodeToString( - GenerateContentResponse(listOf(Candidate(Content(parts = listOf(TextPart(text)))))) + GenerateContentResponse.Internal( + listOf(Candidate.Internal(Content.Internal(parts = listOf(TextPart.Internal(text))))) + ) ) } } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SchemaTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SchemaTests.kt index 747f65ac168..4701d516ff5 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SchemaTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SchemaTests.kt @@ -16,7 +16,6 @@ package com.google.firebase.vertexai -import com.google.firebase.vertexai.internal.util.toInternal import com.google.firebase.vertexai.type.Schema import com.google.firebase.vertexai.type.StringFormat import io.kotest.assertions.json.shouldEqualJson diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt index 8937b13569b..463dbe773f7 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt @@ -17,16 +17,17 @@ package com.google.firebase.vertexai.common import com.google.firebase.vertexai.BuildConfig -import com.google.firebase.vertexai.common.client.FunctionCallingConfig -import com.google.firebase.vertexai.common.client.Tool -import com.google.firebase.vertexai.common.client.ToolConfig -import com.google.firebase.vertexai.common.shared.Content -import com.google.firebase.vertexai.common.shared.TextPart import com.google.firebase.vertexai.common.util.commonTest import com.google.firebase.vertexai.common.util.createResponses import com.google.firebase.vertexai.common.util.doBlocking import com.google.firebase.vertexai.common.util.prepareStreamingResponse +import com.google.firebase.vertexai.type.Content +import com.google.firebase.vertexai.type.CountTokensResponse +import com.google.firebase.vertexai.type.FunctionCallingConfig import com.google.firebase.vertexai.type.RequestOptions +import com.google.firebase.vertexai.type.TextPart +import com.google.firebase.vertexai.type.Tool +import com.google.firebase.vertexai.type.ToolConfig import io.kotest.assertions.json.shouldContainJsonKey import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe @@ -140,7 +141,7 @@ internal class RequestFormatTests { @Test fun `client id header is set correctly in the request`() = doBlocking { - val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10)) + val response = JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10)) val mockEngine = MockEngine { respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) } @@ -183,11 +184,11 @@ internal class RequestFormatTests { .generateContentStream( GenerateContentRequest( model = "unused", - contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))), + contents = listOf(Content.Internal(parts = listOf(TextPart.Internal("Arbitrary")))), toolConfig = - ToolConfig( - FunctionCallingConfig( - mode = FunctionCallingConfig.Mode.ANY, + ToolConfig.Internal( + FunctionCallingConfig.Internal( + mode = FunctionCallingConfig.Internal.Mode.ANY, allowedFunctionNames = listOf("allowedFunctionName") ) ) @@ -205,7 +206,7 @@ internal class RequestFormatTests { @Test fun `headers from HeaderProvider are added to the request`() = doBlocking { - val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10)) + val response = JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10)) val mockEngine = MockEngine { respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) } @@ -237,7 +238,7 @@ internal class RequestFormatTests { @Test fun `headers from HeaderProvider are ignored if timeout`() = doBlocking { - val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10)) + val response = JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10)) val mockEngine = MockEngine { respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) } @@ -291,8 +292,8 @@ internal class RequestFormatTests { .generateContentStream( GenerateContentRequest( model = "unused", - contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))), - tools = listOf(Tool(codeExecution = JsonObject(emptyMap()))), + contents = listOf(Content.Internal(parts = listOf(TextPart.Internal("Arbitrary")))), + tools = listOf(Tool.Internal(codeExecution = JsonObject(emptyMap()))), ) ) .collect { channel.close() } @@ -351,7 +352,7 @@ internal class ModelNamingTests(private val modelName: String, private val actua internal fun textGenerateContentRequest(prompt: String) = GenerateContentRequest( model = "unused", - contents = listOf(Content(parts = listOf(TextPart(prompt)))), + contents = listOf(Content.Internal(parts = listOf(TextPart.Internal(prompt)))), ) internal fun textCountTokenRequest(prompt: String) = diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/EnumUpdateTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/EnumUpdateTests.kt index ddee4dbabf1..769adbd4cd8 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/EnumUpdateTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/EnumUpdateTests.kt @@ -16,7 +16,6 @@ package com.google.firebase.vertexai.common -import com.google.firebase.vertexai.internal.util.toInternal import com.google.firebase.vertexai.type.HarmBlockMethod import com.google.firebase.vertexai.type.HarmBlockThreshold import com.google.firebase.vertexai.type.HarmCategory diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt index 9d470c95ea6..2d29ad38ba7 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt @@ -16,11 +16,11 @@ package com.google.firebase.vertexai.common -import com.google.firebase.vertexai.common.server.BlockReason -import com.google.firebase.vertexai.common.server.FinishReason -import com.google.firebase.vertexai.common.shared.HarmCategory -import com.google.firebase.vertexai.common.shared.TextPart import com.google.firebase.vertexai.common.util.goldenStreamingFile +import com.google.firebase.vertexai.type.BlockReason +import com.google.firebase.vertexai.type.FinishReason +import com.google.firebase.vertexai.type.HarmCategory +import com.google.firebase.vertexai.type.TextPart import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.shouldBe @@ -43,7 +43,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val responseList = responses.toList() responseList.isEmpty() shouldBe false - responseList.first().candidates?.first()?.finishReason shouldBe FinishReason.STOP + responseList.first().candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP responseList.first().candidates?.first()?.content?.parts?.isEmpty() shouldBe false responseList.first().candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -58,7 +58,7 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false responseList.forEach { - it.candidates?.first()?.finishReason shouldBe FinishReason.STOP + it.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP it.candidates?.first()?.content?.parts?.isEmpty() shouldBe false it.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -75,7 +75,7 @@ internal class StreamingSnapshotTests { responseList.isEmpty() shouldBe false responseList.any { it.candidates?.any { - it.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } ?: false + it.safetyRatings?.any { it.category == HarmCategory.Internal.UNKNOWN } ?: false } ?: false } shouldBe true @@ -91,7 +91,8 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false - val part = responseList.first().candidates?.first()?.content?.parts?.first() as? TextPart + val part = + responseList.first().candidates?.first()?.content?.parts?.first() as? TextPart.Internal part.shouldNotBeNull() part.text shouldContain "\"" } @@ -104,7 +105,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY + exception.response.promptFeedback?.blockReason shouldBe BlockReason.Internal.SAFETY } } @@ -131,7 +132,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.SAFETY + exception.response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.SAFETY } } @@ -170,7 +171,8 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.RECITATION + exception.response.candidates?.first()?.finishReason shouldBe + FinishReason.Internal.RECITATION } } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt index 66e6a3f53a5..33ebdda5322 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt @@ -16,15 +16,15 @@ package com.google.firebase.vertexai.common -import com.google.firebase.vertexai.common.server.BlockReason -import com.google.firebase.vertexai.common.server.FinishReason -import com.google.firebase.vertexai.common.server.HarmProbability -import com.google.firebase.vertexai.common.server.HarmSeverity -import com.google.firebase.vertexai.common.shared.FunctionCallPart -import com.google.firebase.vertexai.common.shared.HarmCategory -import com.google.firebase.vertexai.common.shared.TextPart import com.google.firebase.vertexai.common.util.goldenUnaryFile import com.google.firebase.vertexai.common.util.shouldNotBeNullOrEmpty +import com.google.firebase.vertexai.type.BlockReason +import com.google.firebase.vertexai.type.FinishReason +import com.google.firebase.vertexai.type.FunctionCallPart +import com.google.firebase.vertexai.type.HarmCategory +import com.google.firebase.vertexai.type.HarmProbability +import com.google.firebase.vertexai.type.HarmSeverity +import com.google.firebase.vertexai.type.TextPart import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.nulls.shouldNotBeNull @@ -53,7 +53,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -67,7 +67,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -81,7 +81,7 @@ internal class UnarySnapshotTests { response.candidates?.isNullOrEmpty() shouldBe false val candidate = response.candidates?.first() - candidate?.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } shouldBe true + candidate?.safetyRatings?.any { it.category == HarmCategory.Internal.UNKNOWN } shouldBe true } } @@ -94,12 +94,12 @@ internal class UnarySnapshotTests { response.candidates?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.all { - it.probability == HarmProbability.NEGLIGIBLE + it.probability == HarmProbability.Internal.NEGLIGIBLE } shouldBe true response.candidates?.first()?.safetyRatings?.all { it.probabilityScore != null } shouldBe true response.candidates?.first()?.safetyRatings?.all { - it.severity == HarmSeverity.NEGLIGIBLE + it.severity == HarmSeverity.Internal.NEGLIGIBLE } shouldBe true response.candidates?.first()?.safetyRatings?.all { it.severityScore != null } shouldBe true } @@ -111,7 +111,7 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { shouldThrow { apiController.generateContent(textGenerateContentRequest("prompt")) - } should { it.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY } + } should { it.response.promptFeedback?.blockReason shouldBe BlockReason.Internal.SAFETY } } } @@ -153,7 +153,7 @@ internal class UnarySnapshotTests { shouldThrow { apiController.generateContent(textGenerateContentRequest("prompt")) } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.SAFETY + exception.response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.SAFETY } } @@ -191,7 +191,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP response.usageMetadata shouldNotBe null response.usageMetadata?.totalTokenCount shouldBe 363 } @@ -204,7 +204,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP response.usageMetadata shouldNotBe null response.usageMetadata?.promptTokenCount shouldBe 6 response.usageMetadata?.totalTokenCount shouldBe null @@ -231,7 +231,12 @@ internal class UnarySnapshotTests { response.candidates?.isEmpty() shouldBe false with( - response.candidates?.first()?.content?.parts?.first()?.shouldBeInstanceOf() + response.candidates + ?.first() + ?.content + ?.parts + ?.first() + ?.shouldBeInstanceOf() ) { shouldNotBeNull() JSON.decodeFromString>(text).shouldNotBeEmpty() @@ -315,7 +320,8 @@ internal class UnarySnapshotTests { goldenUnaryFile("success-function-call-null.json") { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - val callPart = (response.candidates!!.first().content!!.parts.first() as FunctionCallPart) + val callPart = + (response.candidates!!.first().content!!.parts.first() as FunctionCallPart.Internal) callPart.functionCall.args shouldNotBe null callPart.functionCall.args?.get("season") shouldBe null @@ -333,7 +339,7 @@ internal class UnarySnapshotTests { content.let { it.shouldNotBeNull() it.parts.shouldNotBeEmpty() - it.parts.first().shouldBeInstanceOf() + it.parts.first().shouldBeInstanceOf() } callPart.functionCall.args shouldNotBe null @@ -349,7 +355,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) val content = response.candidates.shouldNotBeNullOrEmpty().first().content content.shouldNotBeNull() - val callPart = content.parts.shouldNotBeNullOrEmpty().first() as FunctionCallPart + val callPart = content.parts.shouldNotBeNullOrEmpty().first() as FunctionCallPart.Internal callPart.functionCall.name shouldBe "current_time" callPart.functionCall.args shouldBe null diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt index 8d1e3bf9f00..5e52b1827b0 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt @@ -20,12 +20,12 @@ package com.google.firebase.vertexai.common.util import com.google.firebase.vertexai.common.APIController import com.google.firebase.vertexai.common.GenerateContentRequest -import com.google.firebase.vertexai.common.GenerateContentResponse import com.google.firebase.vertexai.common.JSON -import com.google.firebase.vertexai.common.server.Candidate -import com.google.firebase.vertexai.common.shared.Content -import com.google.firebase.vertexai.common.shared.TextPart +import com.google.firebase.vertexai.type.Candidate +import com.google.firebase.vertexai.type.Content +import com.google.firebase.vertexai.type.GenerateContentResponse import com.google.firebase.vertexai.type.RequestOptions +import com.google.firebase.vertexai.type.TextPart import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.nulls.shouldNotBeNull import io.ktor.client.engine.mock.MockEngine @@ -43,15 +43,16 @@ import kotlinx.serialization.encodeToString private val TEST_CLIENT_ID = "genai-android/test" -internal fun prepareStreamingResponse(response: List): List = - response.map { "data: ${JSON.encodeToString(it)}$SSE_SEPARATOR".toByteArray() } +internal fun prepareStreamingResponse( + response: List +): List = response.map { "data: ${JSON.encodeToString(it)}$SSE_SEPARATOR".toByteArray() } -internal fun prepareResponse(response: GenerateContentResponse) = +internal fun prepareResponse(response: GenerateContentResponse.Internal) = JSON.encodeToString(response).toByteArray() @OptIn(ExperimentalSerializationApi::class) internal fun createRequest(vararg text: String): GenerateContentRequest { - val contents = text.map { Content(parts = listOf(TextPart(it))) } + val contents = text.map { Content.Internal(parts = listOf(TextPart.Internal(it))) } return GenerateContentRequest("gemini", contents) } @@ -59,10 +60,11 @@ internal fun createRequest(vararg text: String): GenerateContentRequest { internal fun createResponse(text: String) = createResponses(text).single() @OptIn(ExperimentalSerializationApi::class) -internal fun createResponses(vararg text: String): List { - val candidates = text.map { Candidate(Content(parts = listOf(TextPart(it)))) } +internal fun createResponses(vararg text: String): List { + val candidates = + text.map { Candidate.Internal(Content.Internal(parts = listOf(TextPart.Internal(it)))) } - return candidates.map { GenerateContentResponse(candidates = listOf(it)) } + return candidates.map { GenerateContentResponse.Internal(candidates = listOf(it)) } } /**