From bbfaa1eab8bb7c1e41771ed30567a7937f35e777 Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Fri, 17 Jan 2025 16:20:26 -0600 Subject: [PATCH 1/9] Migrate VertexAI serialization to be localized --- .../firebase/vertexai/common/APIController.kt | 17 +- .../firebase/vertexai/common/Exceptions.kt | 5 +- .../firebase/vertexai/common/Request.kt | 28 +- .../firebase/vertexai/common/Response.kt | 46 --- .../firebase/vertexai/common/client/Types.kt | 80 ----- .../firebase/vertexai/common/server/Types.kt | 171 --------- .../firebase/vertexai/common/shared/Types.kt | 115 ------ .../vertexai/internal/util/conversions.kt | 334 +----------------- .../firebase/vertexai/type/Candidate.kt | 171 ++++++++- .../google/firebase/vertexai/type/Content.kt | 17 + .../vertexai/type/CountTokensResponse.kt | 13 + .../vertexai/type/FunctionCallingConfig.kt | 20 +- .../vertexai/type/FunctionDeclaration.kt | 12 + .../vertexai/type/GenerateContentResponse.kt | 17 + .../vertexai/type/GenerationConfig.kt | 31 ++ .../firebase/vertexai/type/HarmBlockMethod.kt | 17 + .../vertexai/type/HarmBlockThreshold.kt | 23 ++ .../firebase/vertexai/type/HarmCategory.kt | 38 ++ .../firebase/vertexai/type/HarmProbability.kt | 26 ++ .../firebase/vertexai/type/HarmSeverity.kt | 26 ++ .../com/google/firebase/vertexai/type/Part.kt | 165 ++++++++- .../firebase/vertexai/type/PromptFeedback.kt | 44 ++- .../firebase/vertexai/type/SafetySetting.kt | 18 +- .../google/firebase/vertexai/type/Schema.kt | 26 ++ .../com/google/firebase/vertexai/type/Tool.kt | 13 + .../firebase/vertexai/type/ToolConfig.kt | 31 +- .../com/google/firebase/vertexai/type/Type.kt | 23 ++ .../firebase/vertexai/type/UsageMetadata.kt | 17 +- .../vertexai/GenerativeModelTesting.kt | 20 +- ...{SchemaTests.kt => InternalSchemaTests.kt} | 2 +- .../vertexai/common/APIControllerTests.kt | 35 +- .../vertexai/common/StreamingSnapshotTests.kt | 22 +- .../vertexai/common/UnarySnapshotTests.kt | 40 +-- .../firebase/vertexai/common/util/tests.kt | 34 +- 34 files changed, 850 insertions(+), 847 deletions(-) delete mode 100644 firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Response.kt delete mode 100644 firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt delete mode 100644 firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/server/Types.kt delete mode 100644 firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt rename firebase-vertexai/src/test/java/com/google/firebase/vertexai/{SchemaTests.kt => InternalSchemaTests.kt} (99%) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt index cbcc0c69e97..12ac704919e 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt @@ -19,12 +19,15 @@ package com.google.firebase.vertexai.common import android.util.Log import com.google.firebase.Firebase import com.google.firebase.options -import com.google.firebase.vertexai.common.server.FinishReason import com.google.firebase.vertexai.common.server.GRpcError import com.google.firebase.vertexai.common.server.GRpcErrorDetails import com.google.firebase.vertexai.common.util.decodeToFlow import com.google.firebase.vertexai.common.util.fullModelName +import com.google.firebase.vertexai.type.GRpcErrorResponse +import com.google.firebase.vertexai.type.InternalCountTokensResponse +import com.google.firebase.vertexai.type.InternalGenerateContentResponse import com.google.firebase.vertexai.type.RequestOptions +import com.google.firebase.vertexai.type.Response import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.engine.HttpClientEngine @@ -106,7 +109,7 @@ internal constructor( install(ContentNegotiation) { json(JSON) } } - suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse = + suspend fun generateContent(request: GenerateContentRequest): InternalGenerateContentResponse = try { client .post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") { @@ -114,15 +117,15 @@ internal constructor( applyHeaderProvider() } .also { validateResponse(it) } - .body() + .body() .validate() } catch (e: Throwable) { throw FirebaseCommonAIException.from(e) } - fun generateContentStream(request: GenerateContentRequest): Flow = + fun generateContentStream(request: GenerateContentRequest): Flow = client - .postStream( + .postStream( "${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" ) { applyCommonConfiguration(request) @@ -130,7 +133,7 @@ internal constructor( .map { it.validate() } .catch { throw FirebaseCommonAIException.from(it) } - suspend fun countTokens(request: CountTokensRequest): CountTokensResponse = + suspend fun countTokens(request: CountTokensRequest): InternalCountTokensResponse = try { client .post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") { @@ -281,7 +284,7 @@ private fun getServiceDisabledErrorDetailsOrNull(error: GRpcError): GRpcErrorDet } } -private fun GenerateContentResponse.validate() = apply { +private fun InternalGenerateContentResponse.validate() = apply { if ((candidates?.isEmpty() != false) && promptFeedback == null) { throw SerializationException("Error deserializing response, found no valid fields") } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt index 41954b13497..621debdf2c7 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt @@ -16,6 +16,7 @@ package com.google.firebase.vertexai.common +import com.google.firebase.vertexai.type.GenerateContentResponse import io.ktor.serialization.JsonConvertException import kotlinx.coroutines.TimeoutCancellationException @@ -66,7 +67,7 @@ internal class InvalidAPIKeyException(message: String, cause: Throwable? = null) * @property response the full server response for the request. */ internal class PromptBlockedException( - val response: GenerateContentResponse, + val response: GenerateContentResponse.InternalGenerateContentResponse, cause: Throwable? = null ) : FirebaseCommonAIException( @@ -98,7 +99,7 @@ internal class InvalidStateException(message: String, cause: Throwable? = null) * @property response the full server response for the request */ internal class ResponseStoppedException( - val response: GenerateContentResponse, + val response: GenerateContentResponse.InternalGenerateContentResponse, cause: Throwable? = null ) : FirebaseCommonAIException( diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt index 39adea5629c..380526b552a 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt @@ -16,12 +16,12 @@ package com.google.firebase.vertexai.common -import com.google.firebase.vertexai.common.client.GenerationConfig -import com.google.firebase.vertexai.common.client.Tool -import com.google.firebase.vertexai.common.client.ToolConfig -import com.google.firebase.vertexai.common.shared.Content -import com.google.firebase.vertexai.common.shared.SafetySetting import com.google.firebase.vertexai.common.util.fullModelName +import com.google.firebase.vertexai.type.Content +import com.google.firebase.vertexai.type.GenerationConfig +import com.google.firebase.vertexai.type.SafetySetting +import com.google.firebase.vertexai.type.Tool +import com.google.firebase.vertexai.type.ToolConfig import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -30,21 +30,21 @@ internal sealed interface Request @Serializable internal data class GenerateContentRequest( val model: String? = null, - val contents: List, - @SerialName("safety_settings") val safetySettings: List? = null, - @SerialName("generation_config") val generationConfig: GenerationConfig? = null, - val tools: List? = null, - @SerialName("tool_config") var toolConfig: ToolConfig? = null, - @SerialName("system_instruction") val systemInstruction: Content? = null, + val contents: List, + @SerialName("safety_settings") val safetySettings: List? = null, + @SerialName("generation_config") val generationConfig: GenerationConfig.InternalGenerationConfig? = null, + val tools: List? = null, + @SerialName("tool_config") var toolConfig: ToolConfig.InternalToolConfig? = null, + @SerialName("system_instruction") val systemInstruction: Content.InternalContent? = null, ) : Request @Serializable internal data class CountTokensRequest( val generateContentRequest: GenerateContentRequest? = null, val model: String? = null, - val contents: List? = null, - val tools: List? = null, - @SerialName("system_instruction") val systemInstruction: Content? = null, + val contents: List? = null, + val tools: List? = null, + @SerialName("system_instruction") val systemInstruction: Content.InternalContent? = null, ) : Request { companion object { fun forGenAI(generateContentRequest: GenerateContentRequest) = diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Response.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Response.kt deleted file mode 100644 index d8182883442..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Response.kt +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common - -import com.google.firebase.vertexai.common.server.Candidate -import com.google.firebase.vertexai.common.server.GRpcError -import com.google.firebase.vertexai.common.server.PromptFeedback -import kotlinx.serialization.Serializable - -internal sealed interface Response - -@Serializable -internal data class GenerateContentResponse( - val candidates: List? = null, - val promptFeedback: PromptFeedback? = null, - val usageMetadata: UsageMetadata? = null, -) : Response - -@Serializable -internal data class CountTokensResponse( - val totalTokens: Int, - val totalBillableCharacters: Int? = null -) : Response - -@Serializable internal data class GRpcErrorResponse(val error: GRpcError) : Response - -@Serializable -internal data class UsageMetadata( - val promptTokenCount: Int? = null, - val candidatesTokenCount: Int? = null, - val totalTokenCount: Int? = null, -) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt deleted file mode 100644 index b950aa3c5f2..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/client/Types.kt +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common.client - -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable -import kotlinx.serialization.json.JsonObject - -@Serializable -internal data class GenerationConfig( - val temperature: Float?, - @SerialName("top_p") val topP: Float?, - @SerialName("top_k") val topK: Int?, - @SerialName("candidate_count") val candidateCount: Int?, - @SerialName("max_output_tokens") val maxOutputTokens: Int?, - @SerialName("stop_sequences") val stopSequences: List?, - @SerialName("response_mime_type") val responseMimeType: String? = null, - @SerialName("presence_penalty") val presencePenalty: Float? = null, - @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, - @SerialName("response_schema") val responseSchema: Schema? = null, -) - -@Serializable -internal data class Tool( - val functionDeclarations: List? = null, - // This is a json object because it is not possible to make a data class with no parameters. - val codeExecution: JsonObject? = null, -) - -@Serializable -internal data class ToolConfig( - @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig? -) - -@Serializable -internal data class FunctionCallingConfig( - val mode: Mode, - @SerialName("allowed_function_names") val allowedFunctionNames: List? = null -) { - @Serializable - enum class Mode { - @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, - AUTO, - ANY, - NONE - } -} - -@Serializable -internal data class FunctionDeclaration( - val name: String, - val description: String, - val parameters: Schema -) - -@Serializable -internal data class Schema( - val type: String, - val description: String? = null, - val format: String? = null, - val nullable: Boolean? = false, - val enum: List? = null, - val properties: Map? = null, - val required: List? = null, - val items: Schema? = null, -) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/server/Types.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/server/Types.kt deleted file mode 100644 index 3749d534e47..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/server/Types.kt +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common.server - -import com.google.firebase.vertexai.common.shared.Content -import com.google.firebase.vertexai.common.shared.HarmCategory -import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.KSerializer -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable -import kotlinx.serialization.json.JsonNames - -internal object BlockReasonSerializer : - KSerializer by FirstOrdinalSerializer(BlockReason::class) - -internal object HarmProbabilitySerializer : - KSerializer by FirstOrdinalSerializer(HarmProbability::class) - -internal object HarmSeveritySerializer : - KSerializer by FirstOrdinalSerializer(HarmSeverity::class) - -internal object FinishReasonSerializer : - KSerializer by FirstOrdinalSerializer(FinishReason::class) - -@Serializable -internal data class PromptFeedback( - val blockReason: BlockReason? = null, - val safetyRatings: List? = null, - val blockReasonMessage: String? = null, -) - -@Serializable(BlockReasonSerializer::class) -internal enum class BlockReason { - UNKNOWN, - @SerialName("BLOCKED_REASON_UNSPECIFIED") UNSPECIFIED, - SAFETY, - OTHER -} - -@Serializable -internal data class Candidate( - val content: Content? = null, - val finishReason: FinishReason? = null, - val safetyRatings: List? = null, - val citationMetadata: CitationMetadata? = null, - val groundingMetadata: GroundingMetadata? = null, -) - -@Serializable -internal data class CitationMetadata -@OptIn(ExperimentalSerializationApi::class) -internal constructor(@JsonNames("citations") val citationSources: List) - -@Serializable -internal data class CitationSources( - val title: String? = null, - val startIndex: Int = 0, - val endIndex: Int, - val uri: String? = null, - val license: String? = null, - val publicationDate: Date? = null, -) - -@Serializable -internal data class Date( - /** Year of the date. Must be between 1 and 9999, or 0 for no year. */ - val year: Int? = null, - /** 1-based index for month. Must be from 1 to 12, or 0 to specify a year without a month. */ - val month: Int? = null, - /** - * Day of a month. Must be from 1 to 31 and valid for the year and month, or 0 to specify a year - * by itself or a year and month where the day isn't significant. - */ - val day: Int? = null, -) - -@Serializable -internal data class SafetyRating( - val category: HarmCategory, - val probability: HarmProbability, - val blocked: Boolean? = null, // TODO(): any reason not to default to false? - val probabilityScore: Float? = null, - val severity: HarmSeverity? = null, - val severityScore: Float? = null, -) - -@Serializable -internal data class GroundingMetadata( - @SerialName("web_search_queries") val webSearchQueries: List?, - @SerialName("search_entry_point") val searchEntryPoint: SearchEntryPoint?, - @SerialName("retrieval_queries") val retrievalQueries: List?, - @SerialName("grounding_attribution") val groundingAttribution: List?, -) - -@Serializable -internal data class SearchEntryPoint( - @SerialName("rendered_content") val renderedContent: String?, - @SerialName("sdk_blob") val sdkBlob: String?, -) - -@Serializable -internal data class GroundingAttribution( - val segment: Segment, - @SerialName("confidence_score") val confidenceScore: Float?, -) - -@Serializable -internal data class Segment( - @SerialName("start_index") val startIndex: Int, - @SerialName("end_index") val endIndex: Int, -) - -@Serializable(HarmProbabilitySerializer::class) -internal enum class HarmProbability { - UNKNOWN, - @SerialName("HARM_PROBABILITY_UNSPECIFIED") UNSPECIFIED, - NEGLIGIBLE, - LOW, - MEDIUM, - HIGH -} - -@Serializable(HarmSeveritySerializer::class) -internal enum class HarmSeverity { - UNKNOWN, - @SerialName("HARM_SEVERITY_UNSPECIFIED") UNSPECIFIED, - @SerialName("HARM_SEVERITY_NEGLIGIBLE") NEGLIGIBLE, - @SerialName("HARM_SEVERITY_LOW") LOW, - @SerialName("HARM_SEVERITY_MEDIUM") MEDIUM, - @SerialName("HARM_SEVERITY_HIGH") HIGH -} - -@Serializable(FinishReasonSerializer::class) -internal enum class FinishReason { - UNKNOWN, - @SerialName("FINISH_REASON_UNSPECIFIED") UNSPECIFIED, - STOP, - MAX_TOKENS, - SAFETY, - RECITATION, - OTHER -} - -@Serializable -internal data class GRpcError( - val code: Int, - val message: String, - val details: List? = null -) - -@Serializable -internal data class GRpcErrorDetails( - val reason: String? = null, - val domain: String? = null, - val metadata: Map? = null -) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt deleted file mode 100644 index b32772e995c..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/shared/Types.kt +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common.shared - -import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer -import kotlinx.serialization.DeserializationStrategy -import kotlinx.serialization.EncodeDefault -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.KSerializer -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable -import kotlinx.serialization.SerializationException -import kotlinx.serialization.json.JsonContentPolymorphicSerializer -import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.jsonObject - -internal object HarmCategorySerializer : - KSerializer by FirstOrdinalSerializer(HarmCategory::class) - -@Serializable(HarmCategorySerializer::class) -internal enum class HarmCategory { - UNKNOWN, - @SerialName("HARM_CATEGORY_HARASSMENT") HARASSMENT, - @SerialName("HARM_CATEGORY_HATE_SPEECH") HATE_SPEECH, - @SerialName("HARM_CATEGORY_SEXUALLY_EXPLICIT") SEXUALLY_EXPLICIT, - @SerialName("HARM_CATEGORY_DANGEROUS_CONTENT") DANGEROUS_CONTENT, - @SerialName("HARM_CATEGORY_CIVIC_INTEGRITY") CIVIC_INTEGRITY, -} - -internal typealias Base64 = String - -@ExperimentalSerializationApi -@Serializable -internal data class Content(@EncodeDefault val role: String? = "user", val parts: List) - -@Serializable(PartSerializer::class) internal sealed interface Part - -@Serializable internal data class TextPart(val text: String) : Part - -@Serializable -internal data class InlineDataPart(@SerialName("inline_data") val inlineData: InlineData) : Part - -@Serializable internal data class FunctionCallPart(val functionCall: FunctionCall) : Part - -@Serializable -internal data class FunctionResponsePart(val functionResponse: FunctionResponse) : Part - -@Serializable internal data class FunctionResponse(val name: String, val response: JsonObject) - -@Serializable -internal data class FunctionCall(val name: String, val args: Map? = null) - -@Serializable -internal data class FileDataPart(@SerialName("file_data") val fileData: FileData) : Part - -@Serializable -internal data class FileData( - @SerialName("mime_type") val mimeType: String, - @SerialName("file_uri") val fileUri: String, -) - -@Serializable -internal data class InlineData(@SerialName("mime_type") val mimeType: String, val data: Base64) - -@Serializable -internal data class SafetySetting( - val category: HarmCategory, - val threshold: HarmBlockThreshold, - val method: HarmBlockMethod? = null, -) - -@Serializable -internal enum class HarmBlockThreshold { - @SerialName("HARM_BLOCK_THRESHOLD_UNSPECIFIED") UNSPECIFIED, - BLOCK_LOW_AND_ABOVE, - BLOCK_MEDIUM_AND_ABOVE, - BLOCK_ONLY_HIGH, - BLOCK_NONE, -} - -@Serializable -internal enum class HarmBlockMethod { - @SerialName("HARM_BLOCK_METHOD_UNSPECIFIED") UNSPECIFIED, - SEVERITY, - PROBABILITY, -} - -internal object PartSerializer : JsonContentPolymorphicSerializer(Part::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { - val jsonObject = element.jsonObject - return when { - "text" in jsonObject -> TextPart.serializer() - "functionCall" in jsonObject -> FunctionCallPart.serializer() - "functionResponse" in jsonObject -> FunctionResponsePart.serializer() - "inlineData" in jsonObject -> InlineDataPart.serializer() - "fileData" in jsonObject -> FileDataPart.serializer() - else -> throw SerializationException("Unknown Part type") - } - } -} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt index f8388054260..e24b7f08f4e 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt @@ -16,92 +16,13 @@ package com.google.firebase.vertexai.internal.util -import android.graphics.Bitmap -import android.graphics.BitmapFactory import android.util.Base64 -import com.google.firebase.vertexai.common.client.Schema -import com.google.firebase.vertexai.common.shared.FileData -import com.google.firebase.vertexai.common.shared.FunctionCall -import com.google.firebase.vertexai.common.shared.FunctionCallPart -import com.google.firebase.vertexai.common.shared.FunctionResponse -import com.google.firebase.vertexai.common.shared.FunctionResponsePart -import com.google.firebase.vertexai.common.shared.InlineData -import com.google.firebase.vertexai.type.BlockReason -import com.google.firebase.vertexai.type.Candidate -import com.google.firebase.vertexai.type.Citation -import com.google.firebase.vertexai.type.CitationMetadata -import com.google.firebase.vertexai.type.Content -import com.google.firebase.vertexai.type.CountTokensResponse -import com.google.firebase.vertexai.type.FileDataPart -import com.google.firebase.vertexai.type.FinishReason -import com.google.firebase.vertexai.type.FunctionCallingConfig -import com.google.firebase.vertexai.type.FunctionDeclaration -import com.google.firebase.vertexai.type.GenerateContentResponse -import com.google.firebase.vertexai.type.GenerationConfig -import com.google.firebase.vertexai.type.HarmBlockMethod -import com.google.firebase.vertexai.type.HarmBlockThreshold -import com.google.firebase.vertexai.type.HarmCategory -import com.google.firebase.vertexai.type.HarmProbability -import com.google.firebase.vertexai.type.HarmSeverity -import com.google.firebase.vertexai.type.ImagePart -import com.google.firebase.vertexai.type.InlineDataPart -import com.google.firebase.vertexai.type.Part -import com.google.firebase.vertexai.type.PromptFeedback -import com.google.firebase.vertexai.type.SafetyRating -import com.google.firebase.vertexai.type.SafetySetting import com.google.firebase.vertexai.type.SerializationException -import com.google.firebase.vertexai.type.TextPart -import com.google.firebase.vertexai.type.Tool -import com.google.firebase.vertexai.type.ToolConfig -import com.google.firebase.vertexai.type.UsageMetadata -import com.google.firebase.vertexai.type.content -import java.io.ByteArrayOutputStream -import java.util.Calendar import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonNull import kotlinx.serialization.json.JsonObject import org.json.JSONObject -private const val BASE_64_FLAGS = Base64.NO_WRAP - -internal fun Content.toInternal() = - com.google.firebase.vertexai.common.shared.Content( - this.role ?: "user", - this.parts.map { it.toInternal() } - ) - -internal fun Part.toInternal(): com.google.firebase.vertexai.common.shared.Part { - return when (this) { - is TextPart -> com.google.firebase.vertexai.common.shared.TextPart(text) - is ImagePart -> - com.google.firebase.vertexai.common.shared.InlineDataPart( - InlineData("image/jpeg", encodeBitmapToBase64Png(image)) - ) - is InlineDataPart -> - com.google.firebase.vertexai.common.shared.InlineDataPart( - InlineData(mimeType, Base64.encodeToString(inlineData, BASE_64_FLAGS)) - ) - is com.google.firebase.vertexai.type.FunctionCallPart -> - FunctionCallPart(FunctionCall(name, args)) - is com.google.firebase.vertexai.type.FunctionResponsePart -> - FunctionResponsePart(FunctionResponse(name, response)) - is FileDataPart -> - com.google.firebase.vertexai.common.shared.FileDataPart( - FileData(mimeType = mimeType, fileUri = uri) - ) - else -> - throw SerializationException( - "The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet." - ) - } -} - -internal fun SafetySetting.toInternal() = - com.google.firebase.vertexai.common.shared.SafetySetting( - harmCategory.toInternal(), - threshold.toInternal(), - method?.toInternal() - ) +internal const val BASE_64_FLAGS = Base64.NO_WRAP internal fun makeMissingCaseException(source: String, ordinal: Int): SerializationException { return SerializationException( @@ -115,259 +36,6 @@ internal fun makeMissingCaseException(source: String, ordinal: Int): Serializati ) } -internal fun GenerationConfig.toInternal() = - com.google.firebase.vertexai.common.client.GenerationConfig( - temperature = temperature, - topP = topP, - topK = topK, - candidateCount = candidateCount, - maxOutputTokens = maxOutputTokens, - stopSequences = stopSequences, - frequencyPenalty = frequencyPenalty, - presencePenalty = presencePenalty, - responseMimeType = responseMimeType, - responseSchema = responseSchema?.toInternal() - ) - -internal fun HarmCategory.toInternal() = - when (this) { - HarmCategory.HARASSMENT -> com.google.firebase.vertexai.common.shared.HarmCategory.HARASSMENT - HarmCategory.HATE_SPEECH -> com.google.firebase.vertexai.common.shared.HarmCategory.HATE_SPEECH - HarmCategory.SEXUALLY_EXPLICIT -> - com.google.firebase.vertexai.common.shared.HarmCategory.SEXUALLY_EXPLICIT - HarmCategory.DANGEROUS_CONTENT -> - com.google.firebase.vertexai.common.shared.HarmCategory.DANGEROUS_CONTENT - HarmCategory.CIVIC_INTEGRITY -> - com.google.firebase.vertexai.common.shared.HarmCategory.CIVIC_INTEGRITY - HarmCategory.UNKNOWN -> com.google.firebase.vertexai.common.shared.HarmCategory.UNKNOWN - else -> throw makeMissingCaseException("HarmCategory", ordinal) - } - -internal fun HarmBlockMethod.toInternal() = - when (this) { - HarmBlockMethod.SEVERITY -> com.google.firebase.vertexai.common.shared.HarmBlockMethod.SEVERITY - HarmBlockMethod.PROBABILITY -> - com.google.firebase.vertexai.common.shared.HarmBlockMethod.PROBABILITY - else -> throw makeMissingCaseException("HarmBlockMethod", ordinal) - } - -internal fun ToolConfig.toInternal() = - com.google.firebase.vertexai.common.client.ToolConfig( - functionCallingConfig?.let { - com.google.firebase.vertexai.common.client.FunctionCallingConfig( - when (it.mode) { - FunctionCallingConfig.Mode.ANY -> - com.google.firebase.vertexai.common.client.FunctionCallingConfig.Mode.ANY - FunctionCallingConfig.Mode.AUTO -> - com.google.firebase.vertexai.common.client.FunctionCallingConfig.Mode.AUTO - FunctionCallingConfig.Mode.NONE -> - com.google.firebase.vertexai.common.client.FunctionCallingConfig.Mode.NONE - }, - it.allowedFunctionNames - ) - } - ) - -internal fun HarmBlockThreshold.toInternal() = - when (this) { - HarmBlockThreshold.NONE -> - com.google.firebase.vertexai.common.shared.HarmBlockThreshold.BLOCK_NONE - HarmBlockThreshold.ONLY_HIGH -> - com.google.firebase.vertexai.common.shared.HarmBlockThreshold.BLOCK_ONLY_HIGH - HarmBlockThreshold.MEDIUM_AND_ABOVE -> - com.google.firebase.vertexai.common.shared.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE - HarmBlockThreshold.LOW_AND_ABOVE -> - com.google.firebase.vertexai.common.shared.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE - else -> throw makeMissingCaseException("HarmBlockThreshold", ordinal) - } - -internal fun Tool.toInternal() = - com.google.firebase.vertexai.common.client.Tool( - functionDeclarations?.map { it.toInternal() } ?: emptyList() - ) - -internal fun FunctionDeclaration.toInternal() = - com.google.firebase.vertexai.common.client.FunctionDeclaration(name, "", schema.toInternal()) - -internal fun com.google.firebase.vertexai.type.Schema.toInternal(): Schema = - Schema( - type, - description, - format, - nullable, - enum, - properties?.mapValues { it.value.toInternal() }, - required, - items?.toInternal(), - ) - internal fun JSONObject.toInternal() = Json.decodeFromString(toString()) -internal fun com.google.firebase.vertexai.common.server.Candidate.toPublic(): Candidate { - val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() - val citations = citationMetadata?.toPublic() - val finishReason = finishReason.toPublic() - - return Candidate( - this.content?.toPublic() ?: content("model") {}, - safetyRatings, - citations, - finishReason - ) -} - -internal fun com.google.firebase.vertexai.common.UsageMetadata.toPublic(): UsageMetadata = - UsageMetadata(promptTokenCount ?: 0, candidatesTokenCount ?: 0, totalTokenCount ?: 0) - -internal fun com.google.firebase.vertexai.common.shared.Content.toPublic(): Content = - Content(role, parts.map { it.toPublic() }) - -internal fun com.google.firebase.vertexai.common.shared.Part.toPublic(): Part { - return when (this) { - is com.google.firebase.vertexai.common.shared.TextPart -> TextPart(text) - is com.google.firebase.vertexai.common.shared.InlineDataPart -> { - val data = Base64.decode(inlineData.data, BASE_64_FLAGS) - if (inlineData.mimeType.contains("image")) { - ImagePart(decodeBitmapFromImage(data)) - } else { - InlineDataPart(data, inlineData.mimeType) - } - } - is FunctionCallPart -> - com.google.firebase.vertexai.type.FunctionCallPart( - functionCall.name, - functionCall.args.orEmpty().mapValues { it.value ?: JsonNull } - ) - is FunctionResponsePart -> - com.google.firebase.vertexai.type.FunctionResponsePart( - functionResponse.name, - functionResponse.response, - ) - is com.google.firebase.vertexai.common.shared.FileDataPart -> - FileDataPart(fileData.mimeType, fileData.fileUri) - else -> - throw SerializationException( - "Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK." - ) - } -} - -internal fun com.google.firebase.vertexai.common.server.CitationSources.toPublic(): Citation { - val publicationDateAsCalendar = - publicationDate?.let { - val calendar = Calendar.getInstance() - // Internal `Date.year` uses 0 to represent not specified. We use 1 as default. - val year = if (it.year == null || it.year < 1) 1 else it.year - // Internal `Date.month` uses 0 to represent not specified, or is 1-12 as months. The month as - // expected by [Calendar] is 0-based, so we subtract 1 or use 0 as default. - val month = if (it.month == null || it.month < 1) 0 else it.month - 1 - // Internal `Date.day` uses 0 to represent not specified. We use 1 as default. - val day = if (it.day == null || it.day < 1) 1 else it.day - calendar.set(year, month, day) - calendar - } - return Citation( - title = title, - startIndex = startIndex, - endIndex = endIndex, - uri = uri, - license = license, - publicationDate = publicationDateAsCalendar - ) -} - -internal fun com.google.firebase.vertexai.common.server.CitationMetadata.toPublic() = - CitationMetadata(citationSources.map { it.toPublic() }) - -internal fun com.google.firebase.vertexai.common.server.SafetyRating.toPublic() = - SafetyRating( - category = category.toPublic(), - probability = probability.toPublic(), - probabilityScore = probabilityScore ?: 0f, - blocked = blocked, - severity = severity?.toPublic(), - severityScore = severityScore - ) - -internal fun com.google.firebase.vertexai.common.server.PromptFeedback.toPublic(): PromptFeedback { - val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() - return com.google.firebase.vertexai.type.PromptFeedback( - blockReason?.toPublic(), - safetyRatings, - blockReasonMessage - ) -} - -internal fun com.google.firebase.vertexai.common.server.FinishReason?.toPublic() = - when (this) { - null -> null - com.google.firebase.vertexai.common.server.FinishReason.MAX_TOKENS -> FinishReason.MAX_TOKENS - com.google.firebase.vertexai.common.server.FinishReason.RECITATION -> FinishReason.RECITATION - com.google.firebase.vertexai.common.server.FinishReason.SAFETY -> FinishReason.SAFETY - com.google.firebase.vertexai.common.server.FinishReason.STOP -> FinishReason.STOP - com.google.firebase.vertexai.common.server.FinishReason.OTHER -> FinishReason.OTHER - else -> FinishReason.UNKNOWN - } - -internal fun com.google.firebase.vertexai.common.shared.HarmCategory.toPublic() = - when (this) { - com.google.firebase.vertexai.common.shared.HarmCategory.HARASSMENT -> HarmCategory.HARASSMENT - com.google.firebase.vertexai.common.shared.HarmCategory.HATE_SPEECH -> HarmCategory.HATE_SPEECH - com.google.firebase.vertexai.common.shared.HarmCategory.SEXUALLY_EXPLICIT -> - HarmCategory.SEXUALLY_EXPLICIT - com.google.firebase.vertexai.common.shared.HarmCategory.DANGEROUS_CONTENT -> - HarmCategory.DANGEROUS_CONTENT - com.google.firebase.vertexai.common.shared.HarmCategory.CIVIC_INTEGRITY -> - HarmCategory.CIVIC_INTEGRITY - else -> HarmCategory.UNKNOWN - } - -internal fun com.google.firebase.vertexai.common.server.HarmProbability.toPublic() = - when (this) { - com.google.firebase.vertexai.common.server.HarmProbability.HIGH -> HarmProbability.HIGH - com.google.firebase.vertexai.common.server.HarmProbability.MEDIUM -> HarmProbability.MEDIUM - com.google.firebase.vertexai.common.server.HarmProbability.LOW -> HarmProbability.LOW - com.google.firebase.vertexai.common.server.HarmProbability.NEGLIGIBLE -> - HarmProbability.NEGLIGIBLE - else -> HarmProbability.UNKNOWN - } - -internal fun com.google.firebase.vertexai.common.server.HarmSeverity.toPublic() = - when (this) { - com.google.firebase.vertexai.common.server.HarmSeverity.HIGH -> HarmSeverity.HIGH - com.google.firebase.vertexai.common.server.HarmSeverity.MEDIUM -> HarmSeverity.MEDIUM - com.google.firebase.vertexai.common.server.HarmSeverity.LOW -> HarmSeverity.LOW - com.google.firebase.vertexai.common.server.HarmSeverity.NEGLIGIBLE -> HarmSeverity.NEGLIGIBLE - else -> HarmSeverity.UNKNOWN - } - -internal fun com.google.firebase.vertexai.common.server.BlockReason.toPublic() = - when (this) { - com.google.firebase.vertexai.common.server.BlockReason.SAFETY -> BlockReason.SAFETY - com.google.firebase.vertexai.common.server.BlockReason.OTHER -> BlockReason.OTHER - else -> BlockReason.UNKNOWN - } - -internal fun com.google.firebase.vertexai.common.GenerateContentResponse.toPublic(): - GenerateContentResponse { - return GenerateContentResponse( - candidates?.map { it.toPublic() }.orEmpty(), - promptFeedback?.toPublic(), - usageMetadata?.toPublic() - ) -} - -internal fun com.google.firebase.vertexai.common.CountTokensResponse.toPublic() = - CountTokensResponse(totalTokens, totalBillableCharacters ?: 0) - internal fun JsonObject.toPublic() = JSONObject(toString()) - -private fun encodeBitmapToBase64Png(input: Bitmap): String { - ByteArrayOutputStream().let { - input.compress(Bitmap.CompressFormat.JPEG, 80, it) - return Base64.encodeToString(it.toByteArray(), BASE_64_FLAGS) - } -} - -private fun decodeBitmapFromImage(input: ByteArray) = - BitmapFactory.decodeByteArray(input, 0, input.size) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt index 6d8d96eb047..2962edbcb65 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt @@ -16,6 +16,12 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonNames import java.util.Calendar /** @@ -32,7 +38,58 @@ internal constructor( public val safetyRatings: List, public val citationMetadata: CitationMetadata?, public val finishReason: FinishReason? -) +) { + + @Serializable + internal data class InternalCandidate( + val content: Content.InternalContent? = null, + val finishReason: FinishReason.InternalFinishReason? = null, + val safetyRatings: List? = null, + val citationMetadata: CitationMetadata.InternalCitationMetadata? = null, + val groundingMetadata: InternalGroundingMetadata? = null, + ) { + internal fun toPublic(): Candidate { + val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() + val citations = citationMetadata?.toPublic() + val finishReason = finishReason?.toPublic() + + return Candidate( + this.content?.toPublic() ?: content("model") {}, + safetyRatings, + citations, + finishReason + ) + } + + @Serializable + internal data class InternalGroundingMetadata( + @SerialName("web_search_queries") val webSearchQueries: List?, + @SerialName("search_entry_point") val searchEntryPoint: InternalSearchEntryPoint?, + @SerialName("retrieval_queries") val retrievalQueries: List?, + @SerialName("grounding_attribution") val groundingAttribution: List?, + ) { + + @Serializable + internal data class InternalSearchEntryPoint( + @SerialName("rendered_content") val renderedContent: String?, + @SerialName("sdk_blob") val sdkBlob: String?, + ) + + @Serializable + internal data class InternalGroundingAttribution( + val segment: InternalSegment, + @SerialName("confidence_score") val confidenceScore: Float?, + ) { + + @Serializable + internal data class InternalSegment( + @SerialName("start_index") val startIndex: Int, + @SerialName("end_index") val endIndex: Int, + ) + } + } + } +} /** * An assessment of the potential harm of some generated content. @@ -55,7 +112,30 @@ internal constructor( public val blocked: Boolean? = null, public val severity: HarmSeverity? = null, public val severityScore: Float? = null -) +) { + + @Serializable + internal data class InternalSafetyRating @JvmOverloads constructor( + val category: HarmCategory.InternalHarmCategory, + val probability: HarmProbability.InternalHarmProbability, + val blocked: Boolean? = null, // TODO(): any reason not to default to false? + val probabilityScore: Float? = null, + val severity: HarmSeverity.InternalHarmSeverity? = null, + val severityScore: Float? = null, + ) { + + internal fun toPublic() = + SafetyRating( + category = category.toPublic(), + probability = probability.toPublic(), + probabilityScore = probabilityScore ?: 0f, + blocked = blocked, + severity = severity?.toPublic(), + severityScore = severityScore + ) + } + +} /** * A collection of source attributions for a piece of content. @@ -63,7 +143,17 @@ internal constructor( * @property citations A list of individual cited sources and the parts of the content to which they * apply. */ -public class CitationMetadata internal constructor(public val citations: List) +public class CitationMetadata internal constructor(public val citations: List) { + + @Serializable + internal data class InternalCitationMetadata + @OptIn(ExperimentalSerializationApi::class) + internal constructor(@JsonNames("citations") val citationSources: List) { + + internal fun toPublic() = + CitationMetadata(citationSources.map { it.toPublic() }) + } +} /** * Represents a citation of content from an external source within the model's output. @@ -89,7 +179,56 @@ internal constructor( public val uri: String? = null, public val license: String? = null, public val publicationDate: Calendar? = null -) +) { + + @Serializable + internal data class InternalCitationSources( + val title: String? = null, + val startIndex: Int = 0, + val endIndex: Int, + val uri: String? = null, + val license: String? = null, + val publicationDate: InternalDate? = null, + ) { + + internal fun toPublic(): Citation { + val publicationDateAsCalendar = + publicationDate?.let { + val calendar = Calendar.getInstance() + // Internal `Date.year` uses 0 to represent not specified. We use 1 as default. + val year = if (it.year == null || it.year < 1) 1 else it.year + // Internal `Date.month` uses 0 to represent not specified, or is 1-12 as months. The month as + // expected by [Calendar] is 0-based, so we subtract 1 or use 0 as default. + val month = if (it.month == null || it.month < 1) 0 else it.month - 1 + // Internal `Date.day` uses 0 to represent not specified. We use 1 as default. + val day = if (it.day == null || it.day < 1) 1 else it.day + calendar.set(year, month, day) + calendar + } + return Citation( + title = title, + startIndex = startIndex, + endIndex = endIndex, + uri = uri, + license = license, + publicationDate = publicationDateAsCalendar + ) + } + + @Serializable + internal data class InternalDate( + /** Year of the date. Must be between 1 and 9999, or 0 for no year. */ + val year: Int? = null, + /** 1-based index for month. Must be from 1 to 12, or 0 to specify a year without a month. */ + val month: Int? = null, + /** + * Day of a month. Must be from 1 to 31 and valid for the year and month, or 0 to specify a year + * by itself or a year and month where the day isn't significant. + */ + val day: Int? = null, + ) + } +} /** * Represents the reason why the model stopped generating content. @@ -98,6 +237,30 @@ internal constructor( * @property ordinal The ordinal value of the finish reason. */ public class FinishReason private constructor(public val name: String, public val ordinal: Int) { + + @Serializable(InternalFinishReason.InternalFinishReasonSerializer::class) + internal enum class InternalFinishReason { + UNKNOWN, + @SerialName("FINISH_REASON_UNSPECIFIED") UNSPECIFIED, + STOP, + MAX_TOKENS, + SAFETY, + RECITATION, + OTHER; + + internal object InternalFinishReasonSerializer : + KSerializer by FirstOrdinalSerializer(InternalFinishReason::class) + + internal fun toPublic() = + when (this) { + MAX_TOKENS -> FinishReason.MAX_TOKENS + RECITATION -> FinishReason.RECITATION + SAFETY -> FinishReason.SAFETY + STOP -> FinishReason.STOP + OTHER -> FinishReason.OTHER + else -> FinishReason.UNKNOWN + } + } public companion object { /** A new and not yet supported value. */ @JvmField public val UNKNOWN: FinishReason = FinishReason("UNKNOWN", 0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt index ec3e9555741..ef9b455784f 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt @@ -17,6 +17,9 @@ package com.google.firebase.vertexai.type import android.graphics.Bitmap +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.Serializable /** * Represents content sent to and received from the model. @@ -76,6 +79,20 @@ constructor(public val role: String? = "user", public val parts: List) { /** Returns a new [Content] using the defined [role] and [parts]. */ public fun build(): Content = Content(role, parts) } + + internal fun toInternal() = + InternalContent( + this.role ?: "user", + this.parts.map { it.toInternal() } + ) + + @ExperimentalSerializationApi + @Serializable + internal data class InternalContent(@EncodeDefault val role: String? = "user", val parts: List) { + + internal fun toPublic(): Content = + Content(role, parts.map { it.toPublic() }) + } } /** diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt index 2835deba6f7..47823ba4e88 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + /** * The model's response to a count tokens request. * @@ -36,4 +38,15 @@ public class CountTokensResponse( public operator fun component1(): Int = totalTokens public operator fun component2(): Int? = totalBillableCharacters + + @Serializable + internal data class InternalCountTokensResponse( + val totalTokens: Int, + val totalBillableCharacters: Int? = null + ) : Response { + + internal fun toPublic(): CountTokensResponse { + return CountTokensResponse(totalTokens, totalBillableCharacters ?: 0) + } + } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt index a2ea9b1d01e..2dd09e169ff 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt @@ -16,6 +16,9 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** * The configuration that specifies the function calling behavior. * @@ -42,7 +45,22 @@ internal constructor( * The model will never predict a function call to answer a query. This can also be achieved by * not passing any tools to the model. */ - NONE + NONE, + } + + + @Serializable + internal data class InternalFunctionCallingConfig( + val mode: Mode, + @SerialName("allowed_function_names") val allowedFunctionNames: List? = null + ) { + @Serializable + enum class Mode { + @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, + AUTO, + ANY, + NONE, + } } public companion object { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt index 672293bb559..b806a6b1714 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + /** * Defines a function that the model can use as a tool. * @@ -58,4 +60,14 @@ public class FunctionDeclaration( ) { internal val schema: Schema = Schema.obj(properties = parameters, optionalProperties = optionalParameters, nullable = false) + + internal fun toInternal() = + InternalFunctionDeclaration(name, "", schema.toInternal()) + + @Serializable + internal data class InternalFunctionDeclaration( + val name: String, + val description: String, + val parameters: Schema.InternalSchema + ) } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt index 85891457b78..8c2452e8b6b 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + /** * A response from the model. * @@ -41,4 +43,19 @@ public class GenerateContentResponse( public val functionCalls: List by lazy { candidates.first().content.parts.filterIsInstance() } + + @Serializable + internal data class InternalGenerateContentResponse( + val candidates: List? = null, + val promptFeedback: PromptFeedback.InternalPromptFeedback? = null, + val usageMetadata: UsageMetadata.InternalUsageMetadata? = null, + ) : Response { + internal fun toPublic(): GenerateContentResponse { + return GenerateContentResponse( + candidates?.map { it.toPublic() }.orEmpty(), + promptFeedback?.toPublic(), + usageMetadata?.toPublic() + ) + } + } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt index 8bf8d7a1ac7..c76fb8b7524 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt @@ -16,6 +16,9 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** * Configuration parameters to use for content generation. * @@ -142,6 +145,34 @@ private constructor( ) } + internal fun toInternal() = + InternalGenerationConfig( + temperature = temperature, + topP = topP, + topK = topK, + candidateCount = candidateCount, + maxOutputTokens = maxOutputTokens, + stopSequences = stopSequences, + frequencyPenalty = frequencyPenalty, + presencePenalty = presencePenalty, + responseMimeType = responseMimeType, + responseSchema = responseSchema?.toInternal() + ) + + @Serializable + internal data class InternalGenerationConfig( + val temperature: Float?, + @SerialName("top_p") val topP: Float?, + @SerialName("top_k") val topK: Int?, + @SerialName("candidate_count") val candidateCount: Int?, + @SerialName("max_output_tokens") val maxOutputTokens: Int?, + @SerialName("stop_sequences") val stopSequences: List?, + @SerialName("response_mime_type") val responseMimeType: String? = null, + @SerialName("presence_penalty") val presencePenalty: Float? = null, + @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, + @SerialName("response_schema") val responseSchema: Schema.InternalSchema? = null, + ) + public companion object { /** diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt index e743964c64a..f42f63c0d54 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt @@ -16,11 +16,28 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.internal.util.makeMissingCaseException +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** * Specifies how the block method computes the score that will be compared against the * [HarmBlockThreshold] in [SafetySetting]. */ public class HarmBlockMethod private constructor(public val ordinal: Int) { + internal fun toInternal() = + when (this) { + SEVERITY -> InternalHarmBlockMethod.SEVERITY + PROBABILITY -> InternalHarmBlockMethod.PROBABILITY + else -> throw makeMissingCaseException("HarmBlockMethod", ordinal) + } + + @Serializable + internal enum class InternalHarmBlockMethod { + @SerialName("HARM_BLOCK_METHOD_UNSPECIFIED") UNSPECIFIED, + SEVERITY, + PROBABILITY, + } public companion object { /** * The harm block method uses both probability and severity scores. See [HarmSeverity] and diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt index 073416112ab..14307c17899 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt @@ -16,8 +16,31 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.internal.util.makeMissingCaseException +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** Represents the threshold for a [HarmCategory] to be allowed by [SafetySetting]. */ public class HarmBlockThreshold private constructor(public val ordinal: Int) { + + internal fun toInternal() = + when (this) { + NONE -> InternalHarmBlockThreshold.BLOCK_NONE + ONLY_HIGH -> InternalHarmBlockThreshold.BLOCK_ONLY_HIGH + MEDIUM_AND_ABOVE -> InternalHarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE + LOW_AND_ABOVE -> InternalHarmBlockThreshold.BLOCK_LOW_AND_ABOVE + else -> throw makeMissingCaseException("HarmBlockThreshold", ordinal) + } + + @Serializable + internal enum class InternalHarmBlockThreshold { + @SerialName("HARM_BLOCK_THRESHOLD_UNSPECIFIED") UNSPECIFIED, + BLOCK_LOW_AND_ABOVE, + BLOCK_MEDIUM_AND_ABOVE, + BLOCK_ONLY_HIGH, + BLOCK_NONE, + } + public companion object { /** Content with negligible harm is allowed. */ @JvmField public val LOW_AND_ABOVE: HarmBlockThreshold = HarmBlockThreshold(0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt index d19de2e1568..ab5fdd927c0 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt @@ -16,8 +16,46 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer +import com.google.firebase.vertexai.internal.util.makeMissingCaseException +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** Category for a given harm rating. */ public class HarmCategory private constructor(public val ordinal: Int) { + internal fun toInternal() = + when (this) { + HARASSMENT -> InternalHarmCategory.HARASSMENT + HATE_SPEECH -> InternalHarmCategory.HATE_SPEECH + SEXUALLY_EXPLICIT -> InternalHarmCategory.SEXUALLY_EXPLICIT + DANGEROUS_CONTENT -> InternalHarmCategory.DANGEROUS_CONTENT + CIVIC_INTEGRITY -> InternalHarmCategory.CIVIC_INTEGRITY + UNKNOWN -> InternalHarmCategory.UNKNOWN + else -> throw makeMissingCaseException("HarmCategory", ordinal) + } + @Serializable(InternalHarmCategory.HarmCategorySerializer::class) + internal enum class InternalHarmCategory { + UNKNOWN, + @SerialName("HARM_CATEGORY_HARASSMENT") HARASSMENT, + @SerialName("HARM_CATEGORY_HATE_SPEECH") HATE_SPEECH, + @SerialName("HARM_CATEGORY_SEXUALLY_EXPLICIT") SEXUALLY_EXPLICIT, + @SerialName("HARM_CATEGORY_DANGEROUS_CONTENT") DANGEROUS_CONTENT, + @SerialName("HARM_CATEGORY_CIVIC_INTEGRITY") CIVIC_INTEGRITY; + + internal object HarmCategorySerializer : + KSerializer by FirstOrdinalSerializer(InternalHarmCategory::class) + + internal fun toPublic() = + when (this) { + HARASSMENT -> HarmCategory.HARASSMENT + HATE_SPEECH -> HarmCategory.HATE_SPEECH + SEXUALLY_EXPLICIT -> HarmCategory.SEXUALLY_EXPLICIT + DANGEROUS_CONTENT -> HarmCategory.DANGEROUS_CONTENT + CIVIC_INTEGRITY -> HarmCategory.CIVIC_INTEGRITY + else -> HarmCategory.UNKNOWN + } + } public companion object { /** A new and not yet supported value. */ @JvmField public val UNKNOWN: HarmCategory = HarmCategory(0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt index d4208f7bf85..19ab5ce524a 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt @@ -16,8 +16,34 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** Represents the probability that some [HarmCategory] is applicable in a [SafetyRating]. */ public class HarmProbability private constructor(public val ordinal: Int) { + @Serializable(InternalHarmProbability.InternalHarmProbabilitySerializer::class) + internal enum class InternalHarmProbability { + UNKNOWN, + @SerialName("HARM_PROBABILITY_UNSPECIFIED") UNSPECIFIED, + NEGLIGIBLE, + LOW, + MEDIUM, + HIGH; + + internal object InternalHarmProbabilitySerializer : + KSerializer by FirstOrdinalSerializer(InternalHarmProbability::class) + + internal fun toPublic() = + when (this) { + HIGH -> HarmProbability.HIGH + MEDIUM -> HarmProbability.MEDIUM + LOW -> HarmProbability.LOW + NEGLIGIBLE -> HarmProbability.NEGLIGIBLE + else -> HarmProbability.UNKNOWN + } + } public companion object { /** A new and not yet supported value. */ @JvmField public val UNKNOWN: HarmProbability = HarmProbability(0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt index 40fe73ca906..b35997ed089 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt @@ -16,8 +16,34 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** Represents the severity of a [HarmCategory] being applicable in a [SafetyRating]. */ public class HarmSeverity private constructor(public val ordinal: Int) { + @Serializable(InternalHarmSeverity.InternalHarmSeveritySerializer::class) + internal enum class InternalHarmSeverity { + UNKNOWN, + @SerialName("HARM_SEVERITY_UNSPECIFIED") UNSPECIFIED, + @SerialName("HARM_SEVERITY_NEGLIGIBLE") NEGLIGIBLE, + @SerialName("HARM_SEVERITY_LOW") LOW, + @SerialName("HARM_SEVERITY_MEDIUM") MEDIUM, + @SerialName("HARM_SEVERITY_HIGH") HIGH; + + internal object InternalHarmSeveritySerializer : + KSerializer by FirstOrdinalSerializer(InternalHarmSeverity::class) + + internal fun toPublic() = + when (this) { + HIGH -> HarmSeverity.HIGH + MEDIUM -> HarmSeverity.MEDIUM + LOW -> HarmSeverity.LOW + NEGLIGIBLE -> HarmSeverity.NEGLIGIBLE + else -> HarmSeverity.UNKNOWN + } + } public companion object { /** A new and not yet supported value. */ @JvmField public val UNKNOWN: HarmSeverity = HarmSeverity(0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt index 41ddfcfbe41..fbbc3f63e51 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt @@ -17,15 +17,30 @@ package com.google.firebase.vertexai.type import android.graphics.Bitmap +import android.graphics.BitmapFactory +import com.google.firebase.vertexai.internal.util.BASE_64_FLAGS +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.JsonContentPolymorphicSerializer import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.jsonObject import org.json.JSONObject +import java.io.ByteArrayOutputStream /** Interface representing data sent to and received from requests. */ -public interface Part +public interface Part { +} /** Represents text or string based data sent to and received from requests. */ -public class TextPart(public val text: String) : Part +public class TextPart(public val text: String) : Part { + + @Serializable + internal data class InternalTextPart(val text: String) : InternalPart +} /** * Represents image data sent to and received from requests. When this is sent to the server it is @@ -42,7 +57,16 @@ public class ImagePart(public val image: Bitmap) : Part * @param mimeType an IANA standard MIME type. For supported values, see the * [Vertex AI documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#media_requirements) */ -public class InlineDataPart(public val inlineData: ByteArray, public val mimeType: String) : Part +public class InlineDataPart(public val inlineData: ByteArray, public val mimeType: String) : Part { + + @Serializable + internal data class InternalInlineDataPart(@SerialName("inline_data") val inlineData: InternalInlineData) : + InternalPart { + + @Serializable + internal data class InternalInlineData(@SerialName("mime_type") val mimeType: String, val data: Base64) + } +} /** * Represents function call name and params received from requests. @@ -51,7 +75,15 @@ public class InlineDataPart(public val inlineData: ByteArray, public val mimeTyp * @param args the function parameters and values as a [Map] */ public class FunctionCallPart(public val name: String, public val args: Map) : - Part + Part { + + @Serializable + internal data class InternalFunctionCallPart(val functionCall: InternalFunctionCall) : InternalPart { + + @Serializable + internal data class InternalFunctionCall(val name: String, val args: Map? = null) + } +} /** * Represents function call output to be returned to the model when it requests a function call. @@ -59,7 +91,16 @@ public class FunctionCallPart(public val name: String, public val args: Map(InternalPart::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + return when { + "text" in jsonObject -> TextPart.InternalTextPart.serializer() + "functionCall" in jsonObject -> FunctionCallPart.InternalFunctionCallPart.serializer() + "functionResponse" in jsonObject -> FunctionResponsePart.InternalFunctionResponsePart.serializer() + "inlineData" in jsonObject -> InlineDataPart.InternalInlineDataPart.serializer() + "fileData" in jsonObject -> FileDataPart.InternalFileDataPart.serializer() + else -> throw SerializationException("Unknown Part type") + } + } +} + +internal fun Part.toInternal(): InternalPart { + return when (this) { + is TextPart -> TextPart.InternalTextPart(text) + is ImagePart -> + InlineDataPart.InternalInlineDataPart( + InlineDataPart.InternalInlineDataPart.InternalInlineData( + "image/jpeg", + encodeBitmapToBase64Png(image) + ) + ) + is InlineDataPart -> + InlineDataPart.InternalInlineDataPart( + InlineDataPart.InternalInlineDataPart.InternalInlineData( + mimeType, + android.util.Base64.encodeToString(inlineData, BASE_64_FLAGS) + ) + ) + is FunctionCallPart -> + FunctionCallPart.InternalFunctionCallPart( + FunctionCallPart.InternalFunctionCallPart.InternalFunctionCall( + name, + args + ) + ) + is FunctionResponsePart -> + FunctionResponsePart.InternalFunctionResponsePart( + FunctionResponsePart.InternalFunctionResponsePart.InternalFunctionResponse( + name, + response + ) + ) + is FileDataPart -> + FileDataPart.InternalFileDataPart( + FileDataPart.InternalFileDataPart.InternalFileData(mimeType = mimeType, fileUri = uri) + ) + else -> + throw com.google.firebase.vertexai.type.SerializationException( + "The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet." + ) + } +} + +private fun encodeBitmapToBase64Png(input: Bitmap): String { + ByteArrayOutputStream().let { + input.compress(Bitmap.CompressFormat.JPEG, 80, it) + return android.util.Base64.encodeToString(it.toByteArray(), BASE_64_FLAGS) + } +} + +internal fun InternalPart.toPublic(): Part { + return when (this) { + is TextPart.InternalTextPart -> TextPart(text) + is InlineDataPart.InternalInlineDataPart -> { + val data = android.util.Base64.decode(inlineData.data, BASE_64_FLAGS) + if (inlineData.mimeType.contains("image")) { + ImagePart(decodeBitmapFromImage(data)) + } else { + InlineDataPart(data, inlineData.mimeType) + } + } + is FunctionCallPart.InternalFunctionCallPart -> + FunctionCallPart( + functionCall.name, + functionCall.args.orEmpty().mapValues { it.value ?: JsonNull } + ) + is FunctionResponsePart.InternalFunctionResponsePart -> + FunctionResponsePart( + functionResponse.name, + functionResponse.response, + ) + is FileDataPart.InternalFileDataPart -> + FileDataPart(fileData.mimeType, fileData.fileUri) + else -> + throw com.google.firebase.vertexai.type.SerializationException( + "Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK." + ) + } +} + +private fun decodeBitmapFromImage(input: ByteArray) = + BitmapFactory.decodeByteArray(input, 0, input.size) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt index b4d06d04b8a..c2859595e82 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt @@ -16,6 +16,12 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer +import com.google.firebase.vertexai.internal.util.toPublic +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** * Feedback on the prompt provided in the request. * @@ -27,10 +33,46 @@ public class PromptFeedback( public val blockReason: BlockReason?, public val safetyRatings: List, public val blockReasonMessage: String? -) +) { + + @Serializable + internal data class InternalPromptFeedback( + val blockReason: BlockReason.InternalBlockReason? = null, + val safetyRatings: List? = null, + val blockReasonMessage: String? = null, + ) { + + internal fun toPublic(): PromptFeedback { + val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() + return PromptFeedback( + blockReason?.toPublic(), + safetyRatings, + blockReasonMessage + ) + } + } +} /** Describes why content was blocked. */ public class BlockReason private constructor(public val name: String, public val ordinal: Int) { + + @Serializable(InternalBlockReason.InternalBlockReasonSerializer::class) + internal enum class InternalBlockReason { + UNKNOWN, + @SerialName("BLOCKED_REASON_UNSPECIFIED") UNSPECIFIED, + SAFETY, + OTHER; + + internal object InternalBlockReasonSerializer : + KSerializer by FirstOrdinalSerializer(InternalBlockReason::class) + + internal fun toPublic() = + when (this) { + SAFETY -> BlockReason.SAFETY + OTHER -> BlockReason.OTHER + else -> BlockReason.UNKNOWN + } + } public companion object { /** A new and not yet supported value. */ @JvmField public val UNKNOWN: BlockReason = BlockReason("UNKNOWN", 0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt index 68d7f93aa99..62a3e3d3593 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + /** * A configuration for a [HarmBlockThreshold] of some [HarmCategory] allowed and blocked in * responses. @@ -29,4 +31,18 @@ public class SafetySetting( internal val harmCategory: HarmCategory, internal val threshold: HarmBlockThreshold, internal val method: HarmBlockMethod? = null, -) +) { + internal fun toInternal() = + InternalSafetySetting( + harmCategory.toInternal(), + threshold.toInternal(), + method?.toInternal() + ) + + @Serializable + internal data class InternalSafetySetting( + val category: HarmCategory.InternalHarmCategory, + val threshold: HarmBlockThreshold.InternalHarmBlockThreshold, + val method: HarmBlockMethod.InternalHarmBlockMethod? = null, + ) +} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt index b6f69e51d49..d507b3b009c 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + public abstract class StringFormat private constructor(internal val value: String) { public class Custom(value: String) : StringFormat(value) } @@ -238,4 +240,28 @@ internal constructor( type = "STRING", ) } + + internal fun toInternal(): InternalSchema = + InternalSchema( + type, + description, + format, + nullable, + enum, + properties?.mapValues { it.value.toInternal() }, + required, + items?.toInternal(), + ) + @Serializable + internal data class InternalSchema( + val type: String, + val description: String? = null, + val format: String? = null, + val nullable: Boolean? = false, + val enum: List? = null, + val properties: Map? = null, + val required: List? = null, + val items: InternalSchema? = null, + ) } + diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt index 41cbf99f6c4..87d2961c965 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt @@ -16,6 +16,9 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + /** * Contains a set of function declarations that the model has access to. These can be used to gather * information, or complete tasks @@ -24,6 +27,16 @@ package com.google.firebase.vertexai.type */ public class Tool internal constructor(internal val functionDeclarations: List?) { + internal fun toInternal() = + InternalTool( + functionDeclarations?.map { it.toInternal() } ?: emptyList() + ) + @Serializable + internal data class InternalTool( + val functionDeclarations: List? = null, + // This is a json object because it is not possible to make a data class with no parameters. + val codeExecution: JsonObject? = null, + ) public companion object { /** diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt index ee26f6b0f57..3417000ba6d 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt @@ -16,10 +16,39 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + /** * Contains configuration for the function calling tools of the model. This can be used to change * when the model can predict function calls. * * @param functionCallingConfig The config for function calling */ -public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfig?) +public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfig?) { + + internal fun toInternal() = + InternalToolConfig( + functionCallingConfig?.let { + FunctionCallingConfig.InternalFunctionCallingConfig( + when (it.mode) { + FunctionCallingConfig.Mode.ANY -> + FunctionCallingConfig.InternalFunctionCallingConfig.Mode.ANY + + FunctionCallingConfig.Mode.AUTO -> + FunctionCallingConfig.InternalFunctionCallingConfig.Mode.AUTO + + FunctionCallingConfig.Mode.NONE -> + FunctionCallingConfig.InternalFunctionCallingConfig.Mode.NONE + }, + it.allowedFunctionNames + ) + } + ) + + @Serializable + internal data class InternalToolConfig( + @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig.InternalFunctionCallingConfig? + ) +} + diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt index ff33240aa24..f58f289cdbb 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt @@ -15,3 +15,26 @@ */ package com.google.firebase.vertexai.type + +import kotlinx.serialization.Serializable + +internal sealed interface Response + +@Serializable +internal data class GRpcErrorResponse(val error: GRpcError) : Response { + + @Serializable + internal data class GRpcError( + val code: Int, + val message: String, + val details: List? = null + ) { + + @Serializable + internal data class GRpcErrorDetails( + val reason: String? = null, + val domain: String? = null, + val metadata: Map? = null + ) + } +} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt index 21da0255cb9..25f2d5eeddb 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt @@ -16,6 +16,8 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.Serializable + /** * Usage metadata about response(s). * @@ -27,4 +29,17 @@ public class UsageMetadata( public val promptTokenCount: Int, public val candidatesTokenCount: Int?, public val totalTokenCount: Int -) +) { + + @Serializable + internal data class InternalUsageMetadata( + val promptTokenCount: Int? = null, + val candidatesTokenCount: Int? = null, + val totalTokenCount: Int? = null, + ) { + + internal fun toPublic(): UsageMetadata = + UsageMetadata(promptTokenCount ?: 0, candidatesTokenCount ?: 0, totalTokenCount ?: 0) + } +} + diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt index 8b668371a31..c8e95fb74de 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt @@ -17,12 +17,12 @@ package com.google.firebase.vertexai import com.google.firebase.vertexai.common.APIController -import com.google.firebase.vertexai.common.GenerateContentResponse import com.google.firebase.vertexai.common.JSON -import com.google.firebase.vertexai.common.server.Candidate -import com.google.firebase.vertexai.common.shared.Content -import com.google.firebase.vertexai.common.shared.TextPart +import com.google.firebase.vertexai.type.InternalTextPart import com.google.firebase.vertexai.common.util.doBlocking +import com.google.firebase.vertexai.type.Candidate +import com.google.firebase.vertexai.type.Content +import com.google.firebase.vertexai.type.GenerateContentResponse import com.google.firebase.vertexai.type.RequestOptions import com.google.firebase.vertexai.type.ServerException import com.google.firebase.vertexai.type.content @@ -129,7 +129,17 @@ internal class GenerativeModelTesting { private fun generateContentResponseAsJsonString(text: String): String { return JSON.encodeToString( - GenerateContentResponse(listOf(Candidate(Content(parts = listOf(TextPart(text)))))) + GenerateContentResponse.InternalGenerateContentResponse( + listOf( + Candidate.InternalCandidate( + Content.InternalContent( + parts = listOf( + InternalTextPart(text) + ) + ) + ) + ) + ) ) } } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SchemaTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalSchemaTests.kt similarity index 99% rename from firebase-vertexai/src/test/java/com/google/firebase/vertexai/SchemaTests.kt rename to firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalSchemaTests.kt index 747f65ac168..df2132d7707 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SchemaTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalSchemaTests.kt @@ -24,7 +24,7 @@ import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import org.junit.Test -internal class SchemaTests { +internal class InternalSchemaTests { @Test fun `basic schema declaration`() { val schemaDeclaration = diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt index 8937b13569b..7b7d4d5720c 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt @@ -17,16 +17,17 @@ package com.google.firebase.vertexai.common import com.google.firebase.vertexai.BuildConfig -import com.google.firebase.vertexai.common.client.FunctionCallingConfig -import com.google.firebase.vertexai.common.client.Tool -import com.google.firebase.vertexai.common.client.ToolConfig -import com.google.firebase.vertexai.common.shared.Content -import com.google.firebase.vertexai.common.shared.TextPart +import com.google.firebase.vertexai.type.InternalTextPart import com.google.firebase.vertexai.common.util.commonTest import com.google.firebase.vertexai.common.util.createResponses import com.google.firebase.vertexai.common.util.doBlocking import com.google.firebase.vertexai.common.util.prepareStreamingResponse +import com.google.firebase.vertexai.type.Content +import com.google.firebase.vertexai.type.CountTokensResponse +import com.google.firebase.vertexai.type.FunctionCallingConfig import com.google.firebase.vertexai.type.RequestOptions +import com.google.firebase.vertexai.type.Tool +import com.google.firebase.vertexai.type.ToolConfig import io.kotest.assertions.json.shouldContainJsonKey import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe @@ -140,7 +141,7 @@ internal class RequestFormatTests { @Test fun `client id header is set correctly in the request`() = doBlocking { - val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10)) + val response = JSON.encodeToString(CountTokensResponse.InternalCountTokensResponse(totalTokens = 10)) val mockEngine = MockEngine { respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) } @@ -183,14 +184,14 @@ internal class RequestFormatTests { .generateContentStream( GenerateContentRequest( model = "unused", - contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))), + contents = listOf(Content.InternalContent(parts = listOf(InternalTextPart("Arbitrary")))), toolConfig = - ToolConfig( - FunctionCallingConfig( - mode = FunctionCallingConfig.Mode.ANY, - allowedFunctionNames = listOf("allowedFunctionName") - ) + ToolConfig.InternalToolConfig( + FunctionCallingConfig.InternalFunctionCallingConfig( + mode = FunctionCallingConfig.InternalFunctionCallingConfig.Mode.ANY, + allowedFunctionNames = listOf("allowedFunctionName") ) + ) ), ) .collect { channel.close() } @@ -205,7 +206,7 @@ internal class RequestFormatTests { @Test fun `headers from HeaderProvider are added to the request`() = doBlocking { - val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10)) + val response = JSON.encodeToString(CountTokensResponse.InternalCountTokensResponse(totalTokens = 10)) val mockEngine = MockEngine { respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) } @@ -237,7 +238,7 @@ internal class RequestFormatTests { @Test fun `headers from HeaderProvider are ignored if timeout`() = doBlocking { - val response = JSON.encodeToString(CountTokensResponse(totalTokens = 10)) + val response = JSON.encodeToString(CountTokensResponse.InternalCountTokensResponse(totalTokens = 10)) val mockEngine = MockEngine { respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) } @@ -291,8 +292,8 @@ internal class RequestFormatTests { .generateContentStream( GenerateContentRequest( model = "unused", - contents = listOf(Content(parts = listOf(TextPart("Arbitrary")))), - tools = listOf(Tool(codeExecution = JsonObject(emptyMap()))), + contents = listOf(Content.InternalContent(parts = listOf(InternalTextPart("Arbitrary")))), + tools = listOf(Tool.InternalTool(codeExecution = JsonObject(emptyMap()))), ) ) .collect { channel.close() } @@ -351,7 +352,7 @@ internal class ModelNamingTests(private val modelName: String, private val actua internal fun textGenerateContentRequest(prompt: String) = GenerateContentRequest( model = "unused", - contents = listOf(Content(parts = listOf(TextPart(prompt)))), + contents = listOf(Content.InternalContent(parts = listOf(InternalTextPart(prompt)))), ) internal fun textCountTokenRequest(prompt: String) = diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt index 9d470c95ea6..11e77460f64 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt @@ -16,11 +16,11 @@ package com.google.firebase.vertexai.common -import com.google.firebase.vertexai.common.server.BlockReason -import com.google.firebase.vertexai.common.server.FinishReason -import com.google.firebase.vertexai.common.shared.HarmCategory -import com.google.firebase.vertexai.common.shared.TextPart +import com.google.firebase.vertexai.type.InternalFinishReason +import com.google.firebase.vertexai.type.InternalTextPart import com.google.firebase.vertexai.common.util.goldenStreamingFile +import com.google.firebase.vertexai.type.BlockReason +import com.google.firebase.vertexai.type.HarmCategory import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.shouldBe @@ -43,7 +43,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val responseList = responses.toList() responseList.isEmpty() shouldBe false - responseList.first().candidates?.first()?.finishReason shouldBe FinishReason.STOP + responseList.first().candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP responseList.first().candidates?.first()?.content?.parts?.isEmpty() shouldBe false responseList.first().candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -58,7 +58,7 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false responseList.forEach { - it.candidates?.first()?.finishReason shouldBe FinishReason.STOP + it.candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP it.candidates?.first()?.content?.parts?.isEmpty() shouldBe false it.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -75,7 +75,7 @@ internal class StreamingSnapshotTests { responseList.isEmpty() shouldBe false responseList.any { it.candidates?.any { - it.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } ?: false + it.safetyRatings?.any { it.category == HarmCategory.InternalHarmCategory.UNKNOWN } ?: false } ?: false } shouldBe true @@ -91,7 +91,7 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false - val part = responseList.first().candidates?.first()?.content?.parts?.first() as? TextPart + val part = responseList.first().candidates?.first()?.content?.parts?.first() as? InternalTextPart part.shouldNotBeNull() part.text shouldContain "\"" } @@ -104,7 +104,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY + exception.response.promptFeedback?.blockReason shouldBe BlockReason.InternalBlockReason.SAFETY } } @@ -131,7 +131,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.SAFETY + exception.response.candidates?.first()?.finishReason shouldBe InternalFinishReason.SAFETY } } @@ -170,7 +170,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.RECITATION + exception.response.candidates?.first()?.finishReason shouldBe InternalFinishReason.RECITATION } } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt index 66e6a3f53a5..8b2a31fe149 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt @@ -16,15 +16,15 @@ package com.google.firebase.vertexai.common -import com.google.firebase.vertexai.common.server.BlockReason -import com.google.firebase.vertexai.common.server.FinishReason -import com.google.firebase.vertexai.common.server.HarmProbability -import com.google.firebase.vertexai.common.server.HarmSeverity -import com.google.firebase.vertexai.common.shared.FunctionCallPart -import com.google.firebase.vertexai.common.shared.HarmCategory -import com.google.firebase.vertexai.common.shared.TextPart +import com.google.firebase.vertexai.type.InternalFinishReason +import com.google.firebase.vertexai.type.InternalFunctionCallPart +import com.google.firebase.vertexai.type.InternalTextPart import com.google.firebase.vertexai.common.util.goldenUnaryFile import com.google.firebase.vertexai.common.util.shouldNotBeNullOrEmpty +import com.google.firebase.vertexai.type.BlockReason +import com.google.firebase.vertexai.type.HarmCategory +import com.google.firebase.vertexai.type.HarmProbability +import com.google.firebase.vertexai.type.HarmSeverity import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.nulls.shouldNotBeNull @@ -53,7 +53,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -67,7 +67,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -81,7 +81,7 @@ internal class UnarySnapshotTests { response.candidates?.isNullOrEmpty() shouldBe false val candidate = response.candidates?.first() - candidate?.safetyRatings?.any { it.category == HarmCategory.UNKNOWN } shouldBe true + candidate?.safetyRatings?.any { it.category == HarmCategory.InternalHarmCategory.UNKNOWN } shouldBe true } } @@ -94,12 +94,12 @@ internal class UnarySnapshotTests { response.candidates?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.all { - it.probability == HarmProbability.NEGLIGIBLE + it.probability == HarmProbability.InternalHarmProbability.NEGLIGIBLE } shouldBe true response.candidates?.first()?.safetyRatings?.all { it.probabilityScore != null } shouldBe true response.candidates?.first()?.safetyRatings?.all { - it.severity == HarmSeverity.NEGLIGIBLE + it.severity == HarmSeverity.InternalHarmSeverity.NEGLIGIBLE } shouldBe true response.candidates?.first()?.safetyRatings?.all { it.severityScore != null } shouldBe true } @@ -111,7 +111,7 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { shouldThrow { apiController.generateContent(textGenerateContentRequest("prompt")) - } should { it.response.promptFeedback?.blockReason shouldBe BlockReason.SAFETY } + } should { it.response.promptFeedback?.blockReason shouldBe BlockReason.InternalBlockReason.SAFETY } } } @@ -153,7 +153,7 @@ internal class UnarySnapshotTests { shouldThrow { apiController.generateContent(textGenerateContentRequest("prompt")) } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.SAFETY + exception.response.candidates?.first()?.finishReason shouldBe InternalFinishReason.SAFETY } } @@ -191,7 +191,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP response.usageMetadata shouldNotBe null response.usageMetadata?.totalTokenCount shouldBe 363 } @@ -204,7 +204,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe FinishReason.STOP + response.candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP response.usageMetadata shouldNotBe null response.usageMetadata?.promptTokenCount shouldBe 6 response.usageMetadata?.totalTokenCount shouldBe null @@ -231,7 +231,7 @@ internal class UnarySnapshotTests { response.candidates?.isEmpty() shouldBe false with( - response.candidates?.first()?.content?.parts?.first()?.shouldBeInstanceOf() + response.candidates?.first()?.content?.parts?.first()?.shouldBeInstanceOf() ) { shouldNotBeNull() JSON.decodeFromString>(text).shouldNotBeEmpty() @@ -315,7 +315,7 @@ internal class UnarySnapshotTests { goldenUnaryFile("success-function-call-null.json") { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - val callPart = (response.candidates!!.first().content!!.parts.first() as FunctionCallPart) + val callPart = (response.candidates!!.first().content!!.parts.first() as InternalFunctionCallPart) callPart.functionCall.args shouldNotBe null callPart.functionCall.args?.get("season") shouldBe null @@ -333,7 +333,7 @@ internal class UnarySnapshotTests { content.let { it.shouldNotBeNull() it.parts.shouldNotBeEmpty() - it.parts.first().shouldBeInstanceOf() + it.parts.first().shouldBeInstanceOf() } callPart.functionCall.args shouldNotBe null @@ -349,7 +349,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) val content = response.candidates.shouldNotBeNullOrEmpty().first().content content.shouldNotBeNull() - val callPart = content.parts.shouldNotBeNullOrEmpty().first() as FunctionCallPart + val callPart = content.parts.shouldNotBeNullOrEmpty().first() as InternalFunctionCallPart callPart.functionCall.name shouldBe "current_time" callPart.functionCall.args shouldBe null diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt index 8d1e3bf9f00..c46535eb841 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt @@ -20,11 +20,11 @@ package com.google.firebase.vertexai.common.util import com.google.firebase.vertexai.common.APIController import com.google.firebase.vertexai.common.GenerateContentRequest -import com.google.firebase.vertexai.common.GenerateContentResponse import com.google.firebase.vertexai.common.JSON -import com.google.firebase.vertexai.common.server.Candidate -import com.google.firebase.vertexai.common.shared.Content -import com.google.firebase.vertexai.common.shared.TextPart +import com.google.firebase.vertexai.type.Candidate +import com.google.firebase.vertexai.type.Content +import com.google.firebase.vertexai.type.GenerateContentResponse +import com.google.firebase.vertexai.type.InternalTextPart import com.google.firebase.vertexai.type.RequestOptions import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.nulls.shouldNotBeNull @@ -43,15 +43,15 @@ import kotlinx.serialization.encodeToString private val TEST_CLIENT_ID = "genai-android/test" -internal fun prepareStreamingResponse(response: List): List = +internal fun prepareStreamingResponse(response: List): List = response.map { "data: ${JSON.encodeToString(it)}$SSE_SEPARATOR".toByteArray() } -internal fun prepareResponse(response: GenerateContentResponse) = +internal fun prepareResponse(response: GenerateContentResponse.InternalGenerateContentResponse) = JSON.encodeToString(response).toByteArray() @OptIn(ExperimentalSerializationApi::class) internal fun createRequest(vararg text: String): GenerateContentRequest { - val contents = text.map { Content(parts = listOf(TextPart(it))) } + val contents = text.map { Content.InternalContent(parts = listOf(InternalTextPart(it))) } return GenerateContentRequest("gemini", contents) } @@ -59,10 +59,24 @@ internal fun createRequest(vararg text: String): GenerateContentRequest { internal fun createResponse(text: String) = createResponses(text).single() @OptIn(ExperimentalSerializationApi::class) -internal fun createResponses(vararg text: String): List { - val candidates = text.map { Candidate(Content(parts = listOf(TextPart(it)))) } +internal fun createResponses(vararg text: String): List { + val candidates = text.map { + Candidate.InternalCandidate( + Content.InternalContent( + parts = listOf( + InternalTextPart(it) + ) + ) + ) + } - return candidates.map { GenerateContentResponse(candidates = listOf(it)) } + return candidates.map { + GenerateContentResponse.InternalGenerateContentResponse( + candidates = listOf( + it + ) + ) + } } /** From 2bcf5d3376ddd058fc6874c8e7d5c81a68405bed Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Wed, 22 Jan 2025 17:12:28 -0600 Subject: [PATCH 2/9] Adjust naming and fix test compilation --- .../firebase/vertexai/common/APIController.kt | 23 ++++--- .../firebase/vertexai/common/Exceptions.kt | 4 +- .../firebase/vertexai/common/Request.kt | 18 +++--- .../firebase/vertexai/type/Candidate.kt | 52 ++++++++-------- .../google/firebase/vertexai/type/Content.kt | 4 +- .../vertexai/type/CountTokensResponse.kt | 2 +- .../vertexai/type/FunctionCallingConfig.kt | 2 +- .../vertexai/type/FunctionDeclaration.kt | 6 +- .../vertexai/type/GenerateContentResponse.kt | 8 +-- .../vertexai/type/GenerationConfig.kt | 6 +- .../firebase/vertexai/type/HarmBlockMethod.kt | 6 +- .../vertexai/type/HarmBlockThreshold.kt | 10 ++-- .../firebase/vertexai/type/HarmCategory.kt | 20 +++---- .../firebase/vertexai/type/HarmProbability.kt | 8 +-- .../firebase/vertexai/type/HarmSeverity.kt | 8 +-- .../com/google/firebase/vertexai/type/Part.kt | 60 +++++++++---------- .../firebase/vertexai/type/PromptFeedback.kt | 15 +++-- .../firebase/vertexai/type/SafetySetting.kt | 10 ++-- .../google/firebase/vertexai/type/Schema.kt | 10 ++-- .../com/google/firebase/vertexai/type/Tool.kt | 6 +- .../firebase/vertexai/type/ToolConfig.kt | 14 ++--- .../firebase/vertexai/type/UsageMetadata.kt | 2 +- .../vertexai/GenerativeModelTesting.kt | 10 ++-- ...nternalSchemaTests.kt => InternalTests.kt} | 2 +- .../vertexai/common/APIControllerTests.kt | 22 +++---- .../vertexai/common/StreamingSnapshotTests.kt | 18 +++--- .../vertexai/common/UnarySnapshotTests.kt | 32 +++++----- .../firebase/vertexai/common/util/tests.kt | 18 +++--- 28 files changed, 197 insertions(+), 199 deletions(-) rename firebase-vertexai/src/test/java/com/google/firebase/vertexai/{InternalSchemaTests.kt => InternalTests.kt} (99%) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt index 12ac704919e..b5dd0b9a0b2 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt @@ -19,13 +19,12 @@ package com.google.firebase.vertexai.common import android.util.Log import com.google.firebase.Firebase import com.google.firebase.options -import com.google.firebase.vertexai.common.server.GRpcError -import com.google.firebase.vertexai.common.server.GRpcErrorDetails import com.google.firebase.vertexai.common.util.decodeToFlow import com.google.firebase.vertexai.common.util.fullModelName +import com.google.firebase.vertexai.type.CountTokensResponse +import com.google.firebase.vertexai.type.FinishReason import com.google.firebase.vertexai.type.GRpcErrorResponse -import com.google.firebase.vertexai.type.InternalCountTokensResponse -import com.google.firebase.vertexai.type.InternalGenerateContentResponse +import com.google.firebase.vertexai.type.GenerateContentResponse import com.google.firebase.vertexai.type.RequestOptions import com.google.firebase.vertexai.type.Response import io.ktor.client.HttpClient @@ -109,7 +108,7 @@ internal constructor( install(ContentNegotiation) { json(JSON) } } - suspend fun generateContent(request: GenerateContentRequest): InternalGenerateContentResponse = + suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse.Internal = try { client .post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:generateContent") { @@ -117,15 +116,15 @@ internal constructor( applyHeaderProvider() } .also { validateResponse(it) } - .body() + .body() .validate() } catch (e: Throwable) { throw FirebaseCommonAIException.from(e) } - fun generateContentStream(request: GenerateContentRequest): Flow = + fun generateContentStream(request: GenerateContentRequest): Flow = client - .postStream( + .postStream( "${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" ) { applyCommonConfiguration(request) @@ -133,7 +132,7 @@ internal constructor( .map { it.validate() } .catch { throw FirebaseCommonAIException.from(it) } - suspend fun countTokens(request: CountTokensRequest): InternalCountTokensResponse = + suspend fun countTokens(request: CountTokensRequest): CountTokensResponse.Internal = try { client .post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:countTokens") { @@ -278,19 +277,19 @@ private suspend fun validateResponse(response: HttpResponse) { throw ServerException(message) } -private fun getServiceDisabledErrorDetailsOrNull(error: GRpcError): GRpcErrorDetails? { +private fun getServiceDisabledErrorDetailsOrNull(error: GRpcErrorResponse.GRpcError): GRpcErrorResponse.GRpcError.GRpcErrorDetails? { return error.details?.firstOrNull { it.reason == "SERVICE_DISABLED" && it.domain == "googleapis.com" } } -private fun InternalGenerateContentResponse.validate() = apply { +private fun GenerateContentResponse.Internal.validate() = apply { if ((candidates?.isEmpty() != false) && promptFeedback == null) { throw SerializationException("Error deserializing response, found no valid fields") } promptFeedback?.blockReason?.let { throw PromptBlockedException(this) } candidates ?.mapNotNull { it.finishReason } - ?.firstOrNull { it != FinishReason.STOP } + ?.firstOrNull { it != FinishReason.Internal.STOP } ?.let { throw ResponseStoppedException(this) } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt index 621debdf2c7..fd38b9813d1 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt @@ -67,7 +67,7 @@ internal class InvalidAPIKeyException(message: String, cause: Throwable? = null) * @property response the full server response for the request. */ internal class PromptBlockedException( - val response: GenerateContentResponse.InternalGenerateContentResponse, + val response: GenerateContentResponse.Internal, cause: Throwable? = null ) : FirebaseCommonAIException( @@ -99,7 +99,7 @@ internal class InvalidStateException(message: String, cause: Throwable? = null) * @property response the full server response for the request */ internal class ResponseStoppedException( - val response: GenerateContentResponse.InternalGenerateContentResponse, + val response: GenerateContentResponse.Internal, cause: Throwable? = null ) : FirebaseCommonAIException( diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt index 380526b552a..040a38e0a0b 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt @@ -30,21 +30,21 @@ internal sealed interface Request @Serializable internal data class GenerateContentRequest( val model: String? = null, - val contents: List, - @SerialName("safety_settings") val safetySettings: List? = null, - @SerialName("generation_config") val generationConfig: GenerationConfig.InternalGenerationConfig? = null, - val tools: List? = null, - @SerialName("tool_config") var toolConfig: ToolConfig.InternalToolConfig? = null, - @SerialName("system_instruction") val systemInstruction: Content.InternalContent? = null, + val contents: List, + @SerialName("safety_settings") val safetySettings: List? = null, + @SerialName("generation_config") val generationConfig: GenerationConfig.Internal? = null, + val tools: List? = null, + @SerialName("tool_config") var toolConfig: ToolConfig.Internal? = null, + @SerialName("system_instruction") val systemInstruction: Content.Internal? = null, ) : Request @Serializable internal data class CountTokensRequest( val generateContentRequest: GenerateContentRequest? = null, val model: String? = null, - val contents: List? = null, - val tools: List? = null, - @SerialName("system_instruction") val systemInstruction: Content.InternalContent? = null, + val contents: List? = null, + val tools: List? = null, + @SerialName("system_instruction") val systemInstruction: Content.Internal? = null, ) : Request { companion object { fun forGenAI(generateContentRequest: GenerateContentRequest) = diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt index 2962edbcb65..6ada813b210 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt @@ -41,12 +41,12 @@ internal constructor( ) { @Serializable - internal data class InternalCandidate( - val content: Content.InternalContent? = null, - val finishReason: FinishReason.InternalFinishReason? = null, - val safetyRatings: List? = null, - val citationMetadata: CitationMetadata.InternalCitationMetadata? = null, - val groundingMetadata: InternalGroundingMetadata? = null, + internal data class Internal( + val content: Content.Internal? = null, + val finishReason: FinishReason.Internal? = null, + val safetyRatings: List? = null, + val citationMetadata: CitationMetadata.Internal? = null, + val groundingMetadata: GroundingMetadata? = null, ) { internal fun toPublic(): Candidate { val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() @@ -62,27 +62,27 @@ internal constructor( } @Serializable - internal data class InternalGroundingMetadata( + internal data class GroundingMetadata( @SerialName("web_search_queries") val webSearchQueries: List?, - @SerialName("search_entry_point") val searchEntryPoint: InternalSearchEntryPoint?, + @SerialName("search_entry_point") val searchEntryPoint: SearchEntryPoint?, @SerialName("retrieval_queries") val retrievalQueries: List?, - @SerialName("grounding_attribution") val groundingAttribution: List?, + @SerialName("grounding_attribution") val groundingAttribution: List?, ) { @Serializable - internal data class InternalSearchEntryPoint( + internal data class SearchEntryPoint( @SerialName("rendered_content") val renderedContent: String?, @SerialName("sdk_blob") val sdkBlob: String?, ) @Serializable - internal data class InternalGroundingAttribution( - val segment: InternalSegment, + internal data class GroundingAttribution( + val segment: Segment, @SerialName("confidence_score") val confidenceScore: Float?, ) { @Serializable - internal data class InternalSegment( + internal data class Segment( @SerialName("start_index") val startIndex: Int, @SerialName("end_index") val endIndex: Int, ) @@ -115,12 +115,12 @@ internal constructor( ) { @Serializable - internal data class InternalSafetyRating @JvmOverloads constructor( - val category: HarmCategory.InternalHarmCategory, - val probability: HarmProbability.InternalHarmProbability, + internal data class Internal @JvmOverloads constructor( + val category: HarmCategory.Internal, + val probability: HarmProbability.Internal, val blocked: Boolean? = null, // TODO(): any reason not to default to false? val probabilityScore: Float? = null, - val severity: HarmSeverity.InternalHarmSeverity? = null, + val severity: HarmSeverity.Internal? = null, val severityScore: Float? = null, ) { @@ -146,9 +146,9 @@ internal constructor( public class CitationMetadata internal constructor(public val citations: List) { @Serializable - internal data class InternalCitationMetadata + internal data class Internal @OptIn(ExperimentalSerializationApi::class) - internal constructor(@JsonNames("citations") val citationSources: List) { + internal constructor(@JsonNames("citations") val citationSources: List) { internal fun toPublic() = CitationMetadata(citationSources.map { it.toPublic() }) @@ -182,13 +182,13 @@ internal constructor( ) { @Serializable - internal data class InternalCitationSources( + internal data class Internal( val title: String? = null, val startIndex: Int = 0, val endIndex: Int, val uri: String? = null, val license: String? = null, - val publicationDate: InternalDate? = null, + val publicationDate: Date? = null, ) { internal fun toPublic(): Citation { @@ -216,7 +216,7 @@ internal constructor( } @Serializable - internal data class InternalDate( + internal data class Date( /** Year of the date. Must be between 1 and 9999, or 0 for no year. */ val year: Int? = null, /** 1-based index for month. Must be from 1 to 12, or 0 to specify a year without a month. */ @@ -238,8 +238,8 @@ internal constructor( */ public class FinishReason private constructor(public val name: String, public val ordinal: Int) { - @Serializable(InternalFinishReason.InternalFinishReasonSerializer::class) - internal enum class InternalFinishReason { + @Serializable(Internal.Serializer::class) + internal enum class Internal { UNKNOWN, @SerialName("FINISH_REASON_UNSPECIFIED") UNSPECIFIED, STOP, @@ -248,8 +248,8 @@ public class FinishReason private constructor(public val name: String, public va RECITATION, OTHER; - internal object InternalFinishReasonSerializer : - KSerializer by FirstOrdinalSerializer(InternalFinishReason::class) + internal object Serializer : + KSerializer by FirstOrdinalSerializer(Internal::class) internal fun toPublic() = when (this) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt index ef9b455784f..c1a6c43bd1f 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt @@ -81,14 +81,14 @@ constructor(public val role: String? = "user", public val parts: List) { } internal fun toInternal() = - InternalContent( + Internal( this.role ?: "user", this.parts.map { it.toInternal() } ) @ExperimentalSerializationApi @Serializable - internal data class InternalContent(@EncodeDefault val role: String? = "user", val parts: List) { + internal data class Internal(@EncodeDefault val role: String? = "user", val parts: List) { internal fun toPublic(): Content = Content(role, parts.map { it.toPublic() }) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt index 47823ba4e88..4692ed7039a 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt @@ -40,7 +40,7 @@ public class CountTokensResponse( public operator fun component2(): Int? = totalBillableCharacters @Serializable - internal data class InternalCountTokensResponse( + internal data class Internal( val totalTokens: Int, val totalBillableCharacters: Int? = null ) : Response { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt index 2dd09e169ff..5805a7c1800 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt @@ -50,7 +50,7 @@ internal constructor( @Serializable - internal data class InternalFunctionCallingConfig( + internal data class Internal( val mode: Mode, @SerialName("allowed_function_names") val allowedFunctionNames: List? = null ) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt index b806a6b1714..6452d92e4b6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt @@ -62,12 +62,12 @@ public class FunctionDeclaration( Schema.obj(properties = parameters, optionalProperties = optionalParameters, nullable = false) internal fun toInternal() = - InternalFunctionDeclaration(name, "", schema.toInternal()) + Internal(name, "", schema.toInternal()) @Serializable - internal data class InternalFunctionDeclaration( + internal data class Internal( val name: String, val description: String, - val parameters: Schema.InternalSchema + val parameters: Schema.Internal ) } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt index 8c2452e8b6b..00395252914 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerateContentResponse.kt @@ -45,10 +45,10 @@ public class GenerateContentResponse( } @Serializable - internal data class InternalGenerateContentResponse( - val candidates: List? = null, - val promptFeedback: PromptFeedback.InternalPromptFeedback? = null, - val usageMetadata: UsageMetadata.InternalUsageMetadata? = null, + internal data class Internal( + val candidates: List? = null, + val promptFeedback: PromptFeedback.Internal? = null, + val usageMetadata: UsageMetadata.Internal? = null, ) : Response { internal fun toPublic(): GenerateContentResponse { return GenerateContentResponse( diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt index c76fb8b7524..4abec8a260d 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt @@ -146,7 +146,7 @@ private constructor( } internal fun toInternal() = - InternalGenerationConfig( + Internal( temperature = temperature, topP = topP, topK = topK, @@ -160,7 +160,7 @@ private constructor( ) @Serializable - internal data class InternalGenerationConfig( + internal data class Internal( val temperature: Float?, @SerialName("top_p") val topP: Float?, @SerialName("top_k") val topK: Int?, @@ -170,7 +170,7 @@ private constructor( @SerialName("response_mime_type") val responseMimeType: String? = null, @SerialName("presence_penalty") val presencePenalty: Float? = null, @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, - @SerialName("response_schema") val responseSchema: Schema.InternalSchema? = null, + @SerialName("response_schema") val responseSchema: Schema.Internal? = null, ) public companion object { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt index f42f63c0d54..d9acb7a28f6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt @@ -27,13 +27,13 @@ import kotlinx.serialization.Serializable public class HarmBlockMethod private constructor(public val ordinal: Int) { internal fun toInternal() = when (this) { - SEVERITY -> InternalHarmBlockMethod.SEVERITY - PROBABILITY -> InternalHarmBlockMethod.PROBABILITY + SEVERITY -> Internal.SEVERITY + PROBABILITY -> Internal.PROBABILITY else -> throw makeMissingCaseException("HarmBlockMethod", ordinal) } @Serializable - internal enum class InternalHarmBlockMethod { + internal enum class Internal { @SerialName("HARM_BLOCK_METHOD_UNSPECIFIED") UNSPECIFIED, SEVERITY, PROBABILITY, diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt index 14307c17899..73e23be2341 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt @@ -25,15 +25,15 @@ public class HarmBlockThreshold private constructor(public val ordinal: Int) { internal fun toInternal() = when (this) { - NONE -> InternalHarmBlockThreshold.BLOCK_NONE - ONLY_HIGH -> InternalHarmBlockThreshold.BLOCK_ONLY_HIGH - MEDIUM_AND_ABOVE -> InternalHarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE - LOW_AND_ABOVE -> InternalHarmBlockThreshold.BLOCK_LOW_AND_ABOVE + NONE -> Internal.BLOCK_NONE + ONLY_HIGH -> Internal.BLOCK_ONLY_HIGH + MEDIUM_AND_ABOVE -> Internal.BLOCK_MEDIUM_AND_ABOVE + LOW_AND_ABOVE -> Internal.BLOCK_LOW_AND_ABOVE else -> throw makeMissingCaseException("HarmBlockThreshold", ordinal) } @Serializable - internal enum class InternalHarmBlockThreshold { + internal enum class Internal { @SerialName("HARM_BLOCK_THRESHOLD_UNSPECIFIED") UNSPECIFIED, BLOCK_LOW_AND_ABOVE, BLOCK_MEDIUM_AND_ABOVE, diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt index ab5fdd927c0..945febb2689 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt @@ -26,16 +26,16 @@ import kotlinx.serialization.Serializable public class HarmCategory private constructor(public val ordinal: Int) { internal fun toInternal() = when (this) { - HARASSMENT -> InternalHarmCategory.HARASSMENT - HATE_SPEECH -> InternalHarmCategory.HATE_SPEECH - SEXUALLY_EXPLICIT -> InternalHarmCategory.SEXUALLY_EXPLICIT - DANGEROUS_CONTENT -> InternalHarmCategory.DANGEROUS_CONTENT - CIVIC_INTEGRITY -> InternalHarmCategory.CIVIC_INTEGRITY - UNKNOWN -> InternalHarmCategory.UNKNOWN + HARASSMENT -> Internal.HARASSMENT + HATE_SPEECH -> Internal.HATE_SPEECH + SEXUALLY_EXPLICIT -> Internal.SEXUALLY_EXPLICIT + DANGEROUS_CONTENT -> Internal.DANGEROUS_CONTENT + CIVIC_INTEGRITY -> Internal.CIVIC_INTEGRITY + UNKNOWN -> Internal.UNKNOWN else -> throw makeMissingCaseException("HarmCategory", ordinal) } - @Serializable(InternalHarmCategory.HarmCategorySerializer::class) - internal enum class InternalHarmCategory { + @Serializable(Internal.Serializer::class) + internal enum class Internal { UNKNOWN, @SerialName("HARM_CATEGORY_HARASSMENT") HARASSMENT, @SerialName("HARM_CATEGORY_HATE_SPEECH") HATE_SPEECH, @@ -43,8 +43,8 @@ public class HarmCategory private constructor(public val ordinal: Int) { @SerialName("HARM_CATEGORY_DANGEROUS_CONTENT") DANGEROUS_CONTENT, @SerialName("HARM_CATEGORY_CIVIC_INTEGRITY") CIVIC_INTEGRITY; - internal object HarmCategorySerializer : - KSerializer by FirstOrdinalSerializer(InternalHarmCategory::class) + internal object Serializer : + KSerializer by FirstOrdinalSerializer(Internal::class) internal fun toPublic() = when (this) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt index 19ab5ce524a..0e406a58dee 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt @@ -23,8 +23,8 @@ import kotlinx.serialization.Serializable /** Represents the probability that some [HarmCategory] is applicable in a [SafetyRating]. */ public class HarmProbability private constructor(public val ordinal: Int) { - @Serializable(InternalHarmProbability.InternalHarmProbabilitySerializer::class) - internal enum class InternalHarmProbability { + @Serializable(Internal.Serializer::class) + internal enum class Internal { UNKNOWN, @SerialName("HARM_PROBABILITY_UNSPECIFIED") UNSPECIFIED, NEGLIGIBLE, @@ -32,8 +32,8 @@ public class HarmProbability private constructor(public val ordinal: Int) { MEDIUM, HIGH; - internal object InternalHarmProbabilitySerializer : - KSerializer by FirstOrdinalSerializer(InternalHarmProbability::class) + internal object Serializer : + KSerializer by FirstOrdinalSerializer(Internal::class) internal fun toPublic() = when (this) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt index b35997ed089..21c4066a728 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt @@ -23,8 +23,8 @@ import kotlinx.serialization.Serializable /** Represents the severity of a [HarmCategory] being applicable in a [SafetyRating]. */ public class HarmSeverity private constructor(public val ordinal: Int) { - @Serializable(InternalHarmSeverity.InternalHarmSeveritySerializer::class) - internal enum class InternalHarmSeverity { + @Serializable(Internal.Serializer::class) + internal enum class Internal { UNKNOWN, @SerialName("HARM_SEVERITY_UNSPECIFIED") UNSPECIFIED, @SerialName("HARM_SEVERITY_NEGLIGIBLE") NEGLIGIBLE, @@ -32,8 +32,8 @@ public class HarmSeverity private constructor(public val ordinal: Int) { @SerialName("HARM_SEVERITY_MEDIUM") MEDIUM, @SerialName("HARM_SEVERITY_HIGH") HIGH; - internal object InternalHarmSeveritySerializer : - KSerializer by FirstOrdinalSerializer(InternalHarmSeverity::class) + internal object Serializer : + KSerializer by FirstOrdinalSerializer(Internal::class) internal fun toPublic() = when (this) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt index fbbc3f63e51..a0789644f54 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt @@ -39,7 +39,7 @@ public interface Part { public class TextPart(public val text: String) : Part { @Serializable - internal data class InternalTextPart(val text: String) : InternalPart + internal data class Internal(val text: String) : InternalPart } /** @@ -60,11 +60,11 @@ public class ImagePart(public val image: Bitmap) : Part public class InlineDataPart(public val inlineData: ByteArray, public val mimeType: String) : Part { @Serializable - internal data class InternalInlineDataPart(@SerialName("inline_data") val inlineData: InternalInlineData) : + internal data class Internal(@SerialName("inline_data") val inlineData: InlineData) : InternalPart { @Serializable - internal data class InternalInlineData(@SerialName("mime_type") val mimeType: String, val data: Base64) + internal data class InlineData(@SerialName("mime_type") val mimeType: String, val data: Base64) } } @@ -78,10 +78,10 @@ public class FunctionCallPart(public val name: String, public val args: Map? = null) + internal data class FunctionCall(val name: String, val args: Map? = null) } } @@ -94,11 +94,11 @@ public class FunctionCallPart(public val name: String, public val args: Map( override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject return when { - "text" in jsonObject -> TextPart.InternalTextPart.serializer() - "functionCall" in jsonObject -> FunctionCallPart.InternalFunctionCallPart.serializer() - "functionResponse" in jsonObject -> FunctionResponsePart.InternalFunctionResponsePart.serializer() - "inlineData" in jsonObject -> InlineDataPart.InternalInlineDataPart.serializer() - "fileData" in jsonObject -> FileDataPart.InternalFileDataPart.serializer() + "text" in jsonObject -> TextPart.Internal.serializer() + "functionCall" in jsonObject -> FunctionCallPart.Internal.serializer() + "functionResponse" in jsonObject -> FunctionResponsePart.Internal.serializer() + "inlineData" in jsonObject -> InlineDataPart.Internal.serializer() + "fileData" in jsonObject -> FileDataPart.Internal.serializer() else -> throw SerializationException("Unknown Part type") } } @@ -156,38 +156,38 @@ internal object PartSerializer : JsonContentPolymorphicSerializer( internal fun Part.toInternal(): InternalPart { return when (this) { - is TextPart -> TextPart.InternalTextPart(text) + is TextPart -> TextPart.Internal(text) is ImagePart -> - InlineDataPart.InternalInlineDataPart( - InlineDataPart.InternalInlineDataPart.InternalInlineData( + InlineDataPart.Internal( + InlineDataPart.Internal.InlineData( "image/jpeg", encodeBitmapToBase64Png(image) ) ) is InlineDataPart -> - InlineDataPart.InternalInlineDataPart( - InlineDataPart.InternalInlineDataPart.InternalInlineData( + InlineDataPart.Internal( + InlineDataPart.Internal.InlineData( mimeType, android.util.Base64.encodeToString(inlineData, BASE_64_FLAGS) ) ) is FunctionCallPart -> - FunctionCallPart.InternalFunctionCallPart( - FunctionCallPart.InternalFunctionCallPart.InternalFunctionCall( + FunctionCallPart.Internal( + FunctionCallPart.Internal.FunctionCall( name, args ) ) is FunctionResponsePart -> - FunctionResponsePart.InternalFunctionResponsePart( - FunctionResponsePart.InternalFunctionResponsePart.InternalFunctionResponse( + FunctionResponsePart.Internal( + FunctionResponsePart.Internal.FunctionResponse( name, response ) ) is FileDataPart -> - FileDataPart.InternalFileDataPart( - FileDataPart.InternalFileDataPart.InternalFileData(mimeType = mimeType, fileUri = uri) + FileDataPart.Internal( + FileDataPart.Internal.FileData(mimeType = mimeType, fileUri = uri) ) else -> throw com.google.firebase.vertexai.type.SerializationException( @@ -205,8 +205,8 @@ private fun encodeBitmapToBase64Png(input: Bitmap): String { internal fun InternalPart.toPublic(): Part { return when (this) { - is TextPart.InternalTextPart -> TextPart(text) - is InlineDataPart.InternalInlineDataPart -> { + is TextPart.Internal -> TextPart(text) + is InlineDataPart.Internal -> { val data = android.util.Base64.decode(inlineData.data, BASE_64_FLAGS) if (inlineData.mimeType.contains("image")) { ImagePart(decodeBitmapFromImage(data)) @@ -214,17 +214,17 @@ internal fun InternalPart.toPublic(): Part { InlineDataPart(data, inlineData.mimeType) } } - is FunctionCallPart.InternalFunctionCallPart -> + is FunctionCallPart.Internal -> FunctionCallPart( functionCall.name, functionCall.args.orEmpty().mapValues { it.value ?: JsonNull } ) - is FunctionResponsePart.InternalFunctionResponsePart -> + is FunctionResponsePart.Internal -> FunctionResponsePart( functionResponse.name, functionResponse.response, ) - is FileDataPart.InternalFileDataPart -> + is FileDataPart.Internal -> FileDataPart(fileData.mimeType, fileData.fileUri) else -> throw com.google.firebase.vertexai.type.SerializationException( diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt index c2859595e82..16ef10b9488 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt @@ -17,7 +17,6 @@ package com.google.firebase.vertexai.type import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer -import com.google.firebase.vertexai.internal.util.toPublic import kotlinx.serialization.KSerializer import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -36,9 +35,9 @@ public class PromptFeedback( ) { @Serializable - internal data class InternalPromptFeedback( - val blockReason: BlockReason.InternalBlockReason? = null, - val safetyRatings: List? = null, + internal data class Internal( + val blockReason: BlockReason.Internal? = null, + val safetyRatings: List? = null, val blockReasonMessage: String? = null, ) { @@ -56,15 +55,15 @@ public class PromptFeedback( /** Describes why content was blocked. */ public class BlockReason private constructor(public val name: String, public val ordinal: Int) { - @Serializable(InternalBlockReason.InternalBlockReasonSerializer::class) - internal enum class InternalBlockReason { + @Serializable(Internal.Serializer::class) + internal enum class Internal { UNKNOWN, @SerialName("BLOCKED_REASON_UNSPECIFIED") UNSPECIFIED, SAFETY, OTHER; - internal object InternalBlockReasonSerializer : - KSerializer by FirstOrdinalSerializer(InternalBlockReason::class) + internal object Serializer : + KSerializer by FirstOrdinalSerializer(Internal::class) internal fun toPublic() = when (this) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt index 62a3e3d3593..351176ea358 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt @@ -33,16 +33,16 @@ public class SafetySetting( internal val method: HarmBlockMethod? = null, ) { internal fun toInternal() = - InternalSafetySetting( + Internal( harmCategory.toInternal(), threshold.toInternal(), method?.toInternal() ) @Serializable - internal data class InternalSafetySetting( - val category: HarmCategory.InternalHarmCategory, - val threshold: HarmBlockThreshold.InternalHarmBlockThreshold, - val method: HarmBlockMethod.InternalHarmBlockMethod? = null, + internal data class Internal( + val category: HarmCategory.Internal, + val threshold: HarmBlockThreshold.Internal, + val method: HarmBlockMethod.Internal? = null, ) } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt index d507b3b009c..5f587effdb8 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt @@ -241,8 +241,8 @@ internal constructor( ) } - internal fun toInternal(): InternalSchema = - InternalSchema( + internal fun toInternal(): Internal = + Internal( type, description, format, @@ -253,15 +253,15 @@ internal constructor( items?.toInternal(), ) @Serializable - internal data class InternalSchema( + internal data class Internal( val type: String, val description: String? = null, val format: String? = null, val nullable: Boolean? = false, val enum: List? = null, - val properties: Map? = null, + val properties: Map? = null, val required: List? = null, - val items: InternalSchema? = null, + val items: Internal? = null, ) } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt index 87d2961c965..48be89e2293 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt @@ -28,12 +28,12 @@ import kotlinx.serialization.json.JsonObject public class Tool internal constructor(internal val functionDeclarations: List?) { internal fun toInternal() = - InternalTool( + Internal( functionDeclarations?.map { it.toInternal() } ?: emptyList() ) @Serializable - internal data class InternalTool( - val functionDeclarations: List? = null, + internal data class Internal( + val functionDeclarations: List? = null, // This is a json object because it is not possible to make a data class with no parameters. val codeExecution: JsonObject? = null, ) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt index 3417000ba6d..d8596a7b6d6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt @@ -28,18 +28,18 @@ import kotlinx.serialization.Serializable public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfig?) { internal fun toInternal() = - InternalToolConfig( + Internal( functionCallingConfig?.let { - FunctionCallingConfig.InternalFunctionCallingConfig( + FunctionCallingConfig.Internal( when (it.mode) { FunctionCallingConfig.Mode.ANY -> - FunctionCallingConfig.InternalFunctionCallingConfig.Mode.ANY + FunctionCallingConfig.Internal.Mode.ANY FunctionCallingConfig.Mode.AUTO -> - FunctionCallingConfig.InternalFunctionCallingConfig.Mode.AUTO + FunctionCallingConfig.Internal.Mode.AUTO FunctionCallingConfig.Mode.NONE -> - FunctionCallingConfig.InternalFunctionCallingConfig.Mode.NONE + FunctionCallingConfig.Internal.Mode.NONE }, it.allowedFunctionNames ) @@ -47,8 +47,8 @@ public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfi ) @Serializable - internal data class InternalToolConfig( - @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig.InternalFunctionCallingConfig? + internal data class Internal( + @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig.Internal? ) } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt index 25f2d5eeddb..232e8b27d11 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt @@ -32,7 +32,7 @@ public class UsageMetadata( ) { @Serializable - internal data class InternalUsageMetadata( + internal data class Internal( val promptTokenCount: Int? = null, val candidatesTokenCount: Int? = null, val totalTokenCount: Int? = null, diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt index c8e95fb74de..288ac4a85a7 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt @@ -18,13 +18,13 @@ package com.google.firebase.vertexai import com.google.firebase.vertexai.common.APIController import com.google.firebase.vertexai.common.JSON -import com.google.firebase.vertexai.type.InternalTextPart import com.google.firebase.vertexai.common.util.doBlocking import com.google.firebase.vertexai.type.Candidate import com.google.firebase.vertexai.type.Content import com.google.firebase.vertexai.type.GenerateContentResponse import com.google.firebase.vertexai.type.RequestOptions import com.google.firebase.vertexai.type.ServerException +import com.google.firebase.vertexai.type.TextPart import com.google.firebase.vertexai.type.content import io.kotest.assertions.json.shouldContainJsonKey import io.kotest.assertions.json.shouldContainJsonKeyValue @@ -129,12 +129,12 @@ internal class GenerativeModelTesting { private fun generateContentResponseAsJsonString(text: String): String { return JSON.encodeToString( - GenerateContentResponse.InternalGenerateContentResponse( + GenerateContentResponse.Internal( listOf( - Candidate.InternalCandidate( - Content.InternalContent( + Candidate.Internal( + Content.Internal( parts = listOf( - InternalTextPart(text) + TextPart.Internal(text) ) ) ) diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalSchemaTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalTests.kt similarity index 99% rename from firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalSchemaTests.kt rename to firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalTests.kt index df2132d7707..cc23a9af859 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalSchemaTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalTests.kt @@ -24,7 +24,7 @@ import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import org.junit.Test -internal class InternalSchemaTests { +internal class InternalTests { @Test fun `basic schema declaration`() { val schemaDeclaration = diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt index 7b7d4d5720c..a26006d44fb 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt @@ -17,7 +17,6 @@ package com.google.firebase.vertexai.common import com.google.firebase.vertexai.BuildConfig -import com.google.firebase.vertexai.type.InternalTextPart import com.google.firebase.vertexai.common.util.commonTest import com.google.firebase.vertexai.common.util.createResponses import com.google.firebase.vertexai.common.util.doBlocking @@ -26,6 +25,7 @@ import com.google.firebase.vertexai.type.Content import com.google.firebase.vertexai.type.CountTokensResponse import com.google.firebase.vertexai.type.FunctionCallingConfig import com.google.firebase.vertexai.type.RequestOptions +import com.google.firebase.vertexai.type.TextPart import com.google.firebase.vertexai.type.Tool import com.google.firebase.vertexai.type.ToolConfig import io.kotest.assertions.json.shouldContainJsonKey @@ -141,7 +141,7 @@ internal class RequestFormatTests { @Test fun `client id header is set correctly in the request`() = doBlocking { - val response = JSON.encodeToString(CountTokensResponse.InternalCountTokensResponse(totalTokens = 10)) + val response = JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10)) val mockEngine = MockEngine { respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) } @@ -184,11 +184,11 @@ internal class RequestFormatTests { .generateContentStream( GenerateContentRequest( model = "unused", - contents = listOf(Content.InternalContent(parts = listOf(InternalTextPart("Arbitrary")))), + contents = listOf(Content.Internal(parts = listOf(TextPart.Internal("Arbitrary")))), toolConfig = - ToolConfig.InternalToolConfig( - FunctionCallingConfig.InternalFunctionCallingConfig( - mode = FunctionCallingConfig.InternalFunctionCallingConfig.Mode.ANY, + ToolConfig.Internal( + FunctionCallingConfig.Internal( + mode = FunctionCallingConfig.Internal.Mode.ANY, allowedFunctionNames = listOf("allowedFunctionName") ) ) @@ -206,7 +206,7 @@ internal class RequestFormatTests { @Test fun `headers from HeaderProvider are added to the request`() = doBlocking { - val response = JSON.encodeToString(CountTokensResponse.InternalCountTokensResponse(totalTokens = 10)) + val response = JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10)) val mockEngine = MockEngine { respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) } @@ -238,7 +238,7 @@ internal class RequestFormatTests { @Test fun `headers from HeaderProvider are ignored if timeout`() = doBlocking { - val response = JSON.encodeToString(CountTokensResponse.InternalCountTokensResponse(totalTokens = 10)) + val response = JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10)) val mockEngine = MockEngine { respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) } @@ -292,8 +292,8 @@ internal class RequestFormatTests { .generateContentStream( GenerateContentRequest( model = "unused", - contents = listOf(Content.InternalContent(parts = listOf(InternalTextPart("Arbitrary")))), - tools = listOf(Tool.InternalTool(codeExecution = JsonObject(emptyMap()))), + contents = listOf(Content.Internal(parts = listOf(TextPart.Internal("Arbitrary")))), + tools = listOf(Tool.Internal(codeExecution = JsonObject(emptyMap()))), ) ) .collect { channel.close() } @@ -352,7 +352,7 @@ internal class ModelNamingTests(private val modelName: String, private val actua internal fun textGenerateContentRequest(prompt: String) = GenerateContentRequest( model = "unused", - contents = listOf(Content.InternalContent(parts = listOf(InternalTextPart(prompt)))), + contents = listOf(Content.Internal(parts = listOf(TextPart.Internal(prompt)))), ) internal fun textCountTokenRequest(prompt: String) = diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt index 11e77460f64..c1ae7d422ce 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt @@ -16,11 +16,11 @@ package com.google.firebase.vertexai.common -import com.google.firebase.vertexai.type.InternalFinishReason -import com.google.firebase.vertexai.type.InternalTextPart import com.google.firebase.vertexai.common.util.goldenStreamingFile import com.google.firebase.vertexai.type.BlockReason +import com.google.firebase.vertexai.type.FinishReason import com.google.firebase.vertexai.type.HarmCategory +import com.google.firebase.vertexai.type.TextPart import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.shouldBe @@ -43,7 +43,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val responseList = responses.toList() responseList.isEmpty() shouldBe false - responseList.first().candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP + responseList.first().candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP responseList.first().candidates?.first()?.content?.parts?.isEmpty() shouldBe false responseList.first().candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -58,7 +58,7 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false responseList.forEach { - it.candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP + it.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP it.candidates?.first()?.content?.parts?.isEmpty() shouldBe false it.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -75,7 +75,7 @@ internal class StreamingSnapshotTests { responseList.isEmpty() shouldBe false responseList.any { it.candidates?.any { - it.safetyRatings?.any { it.category == HarmCategory.InternalHarmCategory.UNKNOWN } ?: false + it.safetyRatings?.any { it.category == HarmCategory.Internal.UNKNOWN } ?: false } ?: false } shouldBe true @@ -91,7 +91,7 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false - val part = responseList.first().candidates?.first()?.content?.parts?.first() as? InternalTextPart + val part = responseList.first().candidates?.first()?.content?.parts?.first() as? TextPart.Internal part.shouldNotBeNull() part.text shouldContain "\"" } @@ -104,7 +104,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.promptFeedback?.blockReason shouldBe BlockReason.InternalBlockReason.SAFETY + exception.response.promptFeedback?.blockReason shouldBe BlockReason.Internal.SAFETY } } @@ -131,7 +131,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe InternalFinishReason.SAFETY + exception.response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.SAFETY } } @@ -170,7 +170,7 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe InternalFinishReason.RECITATION + exception.response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.RECITATION } } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt index 8b2a31fe149..ad85d01b544 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt @@ -16,15 +16,15 @@ package com.google.firebase.vertexai.common -import com.google.firebase.vertexai.type.InternalFinishReason -import com.google.firebase.vertexai.type.InternalFunctionCallPart -import com.google.firebase.vertexai.type.InternalTextPart import com.google.firebase.vertexai.common.util.goldenUnaryFile import com.google.firebase.vertexai.common.util.shouldNotBeNullOrEmpty import com.google.firebase.vertexai.type.BlockReason +import com.google.firebase.vertexai.type.FinishReason +import com.google.firebase.vertexai.type.FunctionCallPart import com.google.firebase.vertexai.type.HarmCategory import com.google.firebase.vertexai.type.HarmProbability import com.google.firebase.vertexai.type.HarmSeverity +import com.google.firebase.vertexai.type.TextPart import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.nulls.shouldNotBeNull @@ -53,7 +53,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP + response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -67,7 +67,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP + response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP response.candidates?.first()?.content?.parts?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false } @@ -81,7 +81,7 @@ internal class UnarySnapshotTests { response.candidates?.isNullOrEmpty() shouldBe false val candidate = response.candidates?.first() - candidate?.safetyRatings?.any { it.category == HarmCategory.InternalHarmCategory.UNKNOWN } shouldBe true + candidate?.safetyRatings?.any { it.category == HarmCategory.Internal.UNKNOWN } shouldBe true } } @@ -94,12 +94,12 @@ internal class UnarySnapshotTests { response.candidates?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.isEmpty() shouldBe false response.candidates?.first()?.safetyRatings?.all { - it.probability == HarmProbability.InternalHarmProbability.NEGLIGIBLE + it.probability == HarmProbability.Internal.NEGLIGIBLE } shouldBe true response.candidates?.first()?.safetyRatings?.all { it.probabilityScore != null } shouldBe true response.candidates?.first()?.safetyRatings?.all { - it.severity == HarmSeverity.InternalHarmSeverity.NEGLIGIBLE + it.severity == HarmSeverity.Internal.NEGLIGIBLE } shouldBe true response.candidates?.first()?.safetyRatings?.all { it.severityScore != null } shouldBe true } @@ -111,7 +111,7 @@ internal class UnarySnapshotTests { withTimeout(testTimeout) { shouldThrow { apiController.generateContent(textGenerateContentRequest("prompt")) - } should { it.response.promptFeedback?.blockReason shouldBe BlockReason.InternalBlockReason.SAFETY } + } should { it.response.promptFeedback?.blockReason shouldBe BlockReason.Internal.SAFETY } } } @@ -153,7 +153,7 @@ internal class UnarySnapshotTests { shouldThrow { apiController.generateContent(textGenerateContentRequest("prompt")) } - exception.response.candidates?.first()?.finishReason shouldBe InternalFinishReason.SAFETY + exception.response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.SAFETY } } @@ -191,7 +191,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP + response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP response.usageMetadata shouldNotBe null response.usageMetadata?.totalTokenCount shouldBe 363 } @@ -204,7 +204,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) response.candidates?.isEmpty() shouldBe false - response.candidates?.first()?.finishReason shouldBe InternalFinishReason.STOP + response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.STOP response.usageMetadata shouldNotBe null response.usageMetadata?.promptTokenCount shouldBe 6 response.usageMetadata?.totalTokenCount shouldBe null @@ -231,7 +231,7 @@ internal class UnarySnapshotTests { response.candidates?.isEmpty() shouldBe false with( - response.candidates?.first()?.content?.parts?.first()?.shouldBeInstanceOf() + response.candidates?.first()?.content?.parts?.first()?.shouldBeInstanceOf() ) { shouldNotBeNull() JSON.decodeFromString>(text).shouldNotBeEmpty() @@ -315,7 +315,7 @@ internal class UnarySnapshotTests { goldenUnaryFile("success-function-call-null.json") { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - val callPart = (response.candidates!!.first().content!!.parts.first() as InternalFunctionCallPart) + val callPart = (response.candidates!!.first().content!!.parts.first() as FunctionCallPart.Internal) callPart.functionCall.args shouldNotBe null callPart.functionCall.args?.get("season") shouldBe null @@ -333,7 +333,7 @@ internal class UnarySnapshotTests { content.let { it.shouldNotBeNull() it.parts.shouldNotBeEmpty() - it.parts.first().shouldBeInstanceOf() + it.parts.first().shouldBeInstanceOf() } callPart.functionCall.args shouldNotBe null @@ -349,7 +349,7 @@ internal class UnarySnapshotTests { val response = apiController.generateContent(textGenerateContentRequest("prompt")) val content = response.candidates.shouldNotBeNullOrEmpty().first().content content.shouldNotBeNull() - val callPart = content.parts.shouldNotBeNullOrEmpty().first() as InternalFunctionCallPart + val callPart = content.parts.shouldNotBeNullOrEmpty().first() as FunctionCallPart.Internal callPart.functionCall.name shouldBe "current_time" callPart.functionCall.args shouldBe null diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt index c46535eb841..750b77f4e24 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt @@ -24,8 +24,8 @@ import com.google.firebase.vertexai.common.JSON import com.google.firebase.vertexai.type.Candidate import com.google.firebase.vertexai.type.Content import com.google.firebase.vertexai.type.GenerateContentResponse -import com.google.firebase.vertexai.type.InternalTextPart import com.google.firebase.vertexai.type.RequestOptions +import com.google.firebase.vertexai.type.TextPart import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.nulls.shouldNotBeNull import io.ktor.client.engine.mock.MockEngine @@ -43,15 +43,15 @@ import kotlinx.serialization.encodeToString private val TEST_CLIENT_ID = "genai-android/test" -internal fun prepareStreamingResponse(response: List): List = +internal fun prepareStreamingResponse(response: List): List = response.map { "data: ${JSON.encodeToString(it)}$SSE_SEPARATOR".toByteArray() } -internal fun prepareResponse(response: GenerateContentResponse.InternalGenerateContentResponse) = +internal fun prepareResponse(response: GenerateContentResponse.Internal) = JSON.encodeToString(response).toByteArray() @OptIn(ExperimentalSerializationApi::class) internal fun createRequest(vararg text: String): GenerateContentRequest { - val contents = text.map { Content.InternalContent(parts = listOf(InternalTextPart(it))) } + val contents = text.map { Content.Internal(parts = listOf(TextPart.Internal(it))) } return GenerateContentRequest("gemini", contents) } @@ -59,19 +59,19 @@ internal fun createRequest(vararg text: String): GenerateContentRequest { internal fun createResponse(text: String) = createResponses(text).single() @OptIn(ExperimentalSerializationApi::class) -internal fun createResponses(vararg text: String): List { +internal fun createResponses(vararg text: String): List { val candidates = text.map { - Candidate.InternalCandidate( - Content.InternalContent( + Candidate.Internal( + Content.Internal( parts = listOf( - InternalTextPart(it) + TextPart.Internal(it) ) ) ) } return candidates.map { - GenerateContentResponse.InternalGenerateContentResponse( + GenerateContentResponse.Internal( candidates = listOf( it ) From 0c0a04b5c4bfc5a5d51a78939861778fb45b1821 Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Thu, 23 Jan 2025 14:05:54 -0600 Subject: [PATCH 3/9] Format --- .../firebase/vertexai/common/APIController.kt | 8 +++- .../firebase/vertexai/type/Candidate.kt | 20 ++++---- .../google/firebase/vertexai/type/Content.kt | 14 +++--- .../vertexai/type/CountTokensResponse.kt | 6 +-- .../vertexai/type/FunctionCallingConfig.kt | 1 - .../vertexai/type/FunctionDeclaration.kt | 3 +- .../firebase/vertexai/type/HarmCategory.kt | 3 +- .../firebase/vertexai/type/HarmProbability.kt | 3 +- .../firebase/vertexai/type/HarmSeverity.kt | 3 +- .../com/google/firebase/vertexai/type/Part.kt | 46 ++++++------------- .../firebase/vertexai/type/PromptFeedback.kt | 9 +--- .../firebase/vertexai/type/SafetySetting.kt | 6 +-- .../google/firebase/vertexai/type/Schema.kt | 1 - .../com/google/firebase/vertexai/type/Tool.kt | 5 +- .../firebase/vertexai/type/ToolConfig.kt | 15 ++---- .../firebase/vertexai/type/UsageMetadata.kt | 1 - .../vertexai/GenerativeModelTesting.kt | 10 +--- .../vertexai/common/APIControllerTests.kt | 10 ++-- .../vertexai/common/StreamingSnapshotTests.kt | 6 ++- .../vertexai/common/UnarySnapshotTests.kt | 10 +++- .../firebase/vertexai/common/util/tests.kt | 24 +++------- 21 files changed, 74 insertions(+), 130 deletions(-) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt index b5dd0b9a0b2..286b8829241 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt @@ -122,7 +122,9 @@ internal constructor( throw FirebaseCommonAIException.from(e) } - fun generateContentStream(request: GenerateContentRequest): Flow = + fun generateContentStream( + request: GenerateContentRequest + ): Flow = client .postStream( "${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:streamGenerateContent?alt=sse" @@ -277,7 +279,9 @@ private suspend fun validateResponse(response: HttpResponse) { throw ServerException(message) } -private fun getServiceDisabledErrorDetailsOrNull(error: GRpcErrorResponse.GRpcError): GRpcErrorResponse.GRpcError.GRpcErrorDetails? { +private fun getServiceDisabledErrorDetailsOrNull( + error: GRpcErrorResponse.GRpcError +): GRpcErrorResponse.GRpcError.GRpcErrorDetails? { return error.details?.firstOrNull { it.reason == "SERVICE_DISABLED" && it.domain == "googleapis.com" } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt index 6ada813b210..54cca4a80b4 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Candidate.kt @@ -17,12 +17,12 @@ package com.google.firebase.vertexai.type import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer +import java.util.Calendar import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.KSerializer import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonNames -import java.util.Calendar /** * A `Candidate` represents a single response generated by the model for a given request. @@ -115,7 +115,9 @@ internal constructor( ) { @Serializable - internal data class Internal @JvmOverloads constructor( + internal data class Internal + @JvmOverloads + constructor( val category: HarmCategory.Internal, val probability: HarmProbability.Internal, val blocked: Boolean? = null, // TODO(): any reason not to default to false? @@ -134,7 +136,6 @@ internal constructor( severityScore = severityScore ) } - } /** @@ -150,8 +151,7 @@ public class CitationMetadata internal constructor(public val citations: List) { - internal fun toPublic() = - CitationMetadata(citationSources.map { it.toPublic() }) + internal fun toPublic() = CitationMetadata(citationSources.map { it.toPublic() }) } } @@ -197,7 +197,8 @@ internal constructor( val calendar = Calendar.getInstance() // Internal `Date.year` uses 0 to represent not specified. We use 1 as default. val year = if (it.year == null || it.year < 1) 1 else it.year - // Internal `Date.month` uses 0 to represent not specified, or is 1-12 as months. The month as + // Internal `Date.month` uses 0 to represent not specified, or is 1-12 as months. The + // month as // expected by [Calendar] is 0-based, so we subtract 1 or use 0 as default. val month = if (it.month == null || it.month < 1) 0 else it.month - 1 // Internal `Date.day` uses 0 to represent not specified. We use 1 as default. @@ -222,8 +223,8 @@ internal constructor( /** 1-based index for month. Must be from 1 to 12, or 0 to specify a year without a month. */ val month: Int? = null, /** - * Day of a month. Must be from 1 to 31 and valid for the year and month, or 0 to specify a year - * by itself or a year and month where the day isn't significant. + * Day of a month. Must be from 1 to 31 and valid for the year and month, or 0 to specify a + * year by itself or a year and month where the day isn't significant. */ val day: Int? = null, ) @@ -248,8 +249,7 @@ public class FinishReason private constructor(public val name: String, public va RECITATION, OTHER; - internal object Serializer : - KSerializer by FirstOrdinalSerializer(Internal::class) + internal object Serializer : KSerializer by FirstOrdinalSerializer(Internal::class) internal fun toPublic() = when (this) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt index c1a6c43bd1f..6d2285a281d 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt @@ -80,18 +80,16 @@ constructor(public val role: String? = "user", public val parts: List) { public fun build(): Content = Content(role, parts) } - internal fun toInternal() = - Internal( - this.role ?: "user", - this.parts.map { it.toInternal() } - ) + internal fun toInternal() = Internal(this.role ?: "user", this.parts.map { it.toInternal() }) @ExperimentalSerializationApi @Serializable - internal data class Internal(@EncodeDefault val role: String? = "user", val parts: List) { + internal data class Internal( + @EncodeDefault val role: String? = "user", + val parts: List + ) { - internal fun toPublic(): Content = - Content(role, parts.map { it.toPublic() }) + internal fun toPublic(): Content = Content(role, parts.map { it.toPublic() }) } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt index 4692ed7039a..4c05521ad65 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/CountTokensResponse.kt @@ -40,10 +40,8 @@ public class CountTokensResponse( public operator fun component2(): Int? = totalBillableCharacters @Serializable - internal data class Internal( - val totalTokens: Int, - val totalBillableCharacters: Int? = null - ) : Response { + internal data class Internal(val totalTokens: Int, val totalBillableCharacters: Int? = null) : + Response { internal fun toPublic(): CountTokensResponse { return CountTokensResponse(totalTokens, totalBillableCharacters ?: 0) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt index 5805a7c1800..ee557556bbc 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt @@ -48,7 +48,6 @@ internal constructor( NONE, } - @Serializable internal data class Internal( val mode: Mode, diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt index 6452d92e4b6..8813de18b43 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionDeclaration.kt @@ -61,8 +61,7 @@ public class FunctionDeclaration( internal val schema: Schema = Schema.obj(properties = parameters, optionalProperties = optionalParameters, nullable = false) - internal fun toInternal() = - Internal(name, "", schema.toInternal()) + internal fun toInternal() = Internal(name, "", schema.toInternal()) @Serializable internal data class Internal( diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt index 945febb2689..857e674ed5f 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt @@ -43,8 +43,7 @@ public class HarmCategory private constructor(public val ordinal: Int) { @SerialName("HARM_CATEGORY_DANGEROUS_CONTENT") DANGEROUS_CONTENT, @SerialName("HARM_CATEGORY_CIVIC_INTEGRITY") CIVIC_INTEGRITY; - internal object Serializer : - KSerializer by FirstOrdinalSerializer(Internal::class) + internal object Serializer : KSerializer by FirstOrdinalSerializer(Internal::class) internal fun toPublic() = when (this) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt index 0e406a58dee..3d13e177819 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmProbability.kt @@ -32,8 +32,7 @@ public class HarmProbability private constructor(public val ordinal: Int) { MEDIUM, HIGH; - internal object Serializer : - KSerializer by FirstOrdinalSerializer(Internal::class) + internal object Serializer : KSerializer by FirstOrdinalSerializer(Internal::class) internal fun toPublic() = when (this) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt index 21c4066a728..0d0a39f2ac9 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmSeverity.kt @@ -32,8 +32,7 @@ public class HarmSeverity private constructor(public val ordinal: Int) { @SerialName("HARM_SEVERITY_MEDIUM") MEDIUM, @SerialName("HARM_SEVERITY_HIGH") HIGH; - internal object Serializer : - KSerializer by FirstOrdinalSerializer(Internal::class) + internal object Serializer : KSerializer by FirstOrdinalSerializer(Internal::class) internal fun toPublic() = when (this) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt index a0789644f54..000415d041d 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt @@ -19,6 +19,7 @@ package com.google.firebase.vertexai.type import android.graphics.Bitmap import android.graphics.BitmapFactory import com.google.firebase.vertexai.internal.util.BASE_64_FLAGS +import java.io.ByteArrayOutputStream import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -29,17 +30,14 @@ import kotlinx.serialization.json.JsonNull import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.jsonObject import org.json.JSONObject -import java.io.ByteArrayOutputStream /** Interface representing data sent to and received from requests. */ -public interface Part { -} +public interface Part {} /** Represents text or string based data sent to and received from requests. */ public class TextPart(public val text: String) : Part { - @Serializable - internal data class Internal(val text: String) : InternalPart + @Serializable internal data class Internal(val text: String) : InternalPart } /** @@ -94,11 +92,9 @@ public class FunctionCallPart(public val name: String, public val args: Map(InternalPart::class) { +internal object PartSerializer : + JsonContentPolymorphicSerializer(InternalPart::class) { override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject return when { @@ -159,10 +155,7 @@ internal fun Part.toInternal(): InternalPart { is TextPart -> TextPart.Internal(text) is ImagePart -> InlineDataPart.Internal( - InlineDataPart.Internal.InlineData( - "image/jpeg", - encodeBitmapToBase64Png(image) - ) + InlineDataPart.Internal.InlineData("image/jpeg", encodeBitmapToBase64Png(image)) ) is InlineDataPart -> InlineDataPart.Internal( @@ -172,23 +165,11 @@ internal fun Part.toInternal(): InternalPart { ) ) is FunctionCallPart -> - FunctionCallPart.Internal( - FunctionCallPart.Internal.FunctionCall( - name, - args - ) - ) + FunctionCallPart.Internal(FunctionCallPart.Internal.FunctionCall(name, args)) is FunctionResponsePart -> - FunctionResponsePart.Internal( - FunctionResponsePart.Internal.FunctionResponse( - name, - response - ) - ) + FunctionResponsePart.Internal(FunctionResponsePart.Internal.FunctionResponse(name, response)) is FileDataPart -> - FileDataPart.Internal( - FileDataPart.Internal.FileData(mimeType = mimeType, fileUri = uri) - ) + FileDataPart.Internal(FileDataPart.Internal.FileData(mimeType = mimeType, fileUri = uri)) else -> throw com.google.firebase.vertexai.type.SerializationException( "The given subclass of Part (${javaClass.simpleName}) is not supported in the serialization yet." @@ -224,8 +205,7 @@ internal fun InternalPart.toPublic(): Part { functionResponse.name, functionResponse.response, ) - is FileDataPart.Internal -> - FileDataPart(fileData.mimeType, fileData.fileUri) + is FileDataPart.Internal -> FileDataPart(fileData.mimeType, fileData.fileUri) else -> throw com.google.firebase.vertexai.type.SerializationException( "Unsupported part type \"${javaClass.simpleName}\" provided. This model may not be supported by this SDK." diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt index 16ef10b9488..5f0e0edd017 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/PromptFeedback.kt @@ -43,11 +43,7 @@ public class PromptFeedback( internal fun toPublic(): PromptFeedback { val safetyRatings = safetyRatings?.map { it.toPublic() }.orEmpty() - return PromptFeedback( - blockReason?.toPublic(), - safetyRatings, - blockReasonMessage - ) + return PromptFeedback(blockReason?.toPublic(), safetyRatings, blockReasonMessage) } } } @@ -62,8 +58,7 @@ public class BlockReason private constructor(public val name: String, public val SAFETY, OTHER; - internal object Serializer : - KSerializer by FirstOrdinalSerializer(Internal::class) + internal object Serializer : KSerializer by FirstOrdinalSerializer(Internal::class) internal fun toPublic() = when (this) { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt index 351176ea358..8095c42c532 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/SafetySetting.kt @@ -33,11 +33,7 @@ public class SafetySetting( internal val method: HarmBlockMethod? = null, ) { internal fun toInternal() = - Internal( - harmCategory.toInternal(), - threshold.toInternal(), - method?.toInternal() - ) + Internal(harmCategory.toInternal(), threshold.toInternal(), method?.toInternal()) @Serializable internal data class Internal( diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt index 5f587effdb8..869d83b0eb9 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Schema.kt @@ -264,4 +264,3 @@ internal constructor( val items: Internal? = null, ) } - diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt index 48be89e2293..e62e02f55b1 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Tool.kt @@ -27,10 +27,7 @@ import kotlinx.serialization.json.JsonObject */ public class Tool internal constructor(internal val functionDeclarations: List?) { - internal fun toInternal() = - Internal( - functionDeclarations?.map { it.toInternal() } ?: emptyList() - ) + internal fun toInternal() = Internal(functionDeclarations?.map { it.toInternal() } ?: emptyList()) @Serializable internal data class Internal( val functionDeclarations: List? = null, diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt index d8596a7b6d6..99769ed46b6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt @@ -32,14 +32,9 @@ public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfi functionCallingConfig?.let { FunctionCallingConfig.Internal( when (it.mode) { - FunctionCallingConfig.Mode.ANY -> - FunctionCallingConfig.Internal.Mode.ANY - - FunctionCallingConfig.Mode.AUTO -> - FunctionCallingConfig.Internal.Mode.AUTO - - FunctionCallingConfig.Mode.NONE -> - FunctionCallingConfig.Internal.Mode.NONE + FunctionCallingConfig.Mode.ANY -> FunctionCallingConfig.Internal.Mode.ANY + FunctionCallingConfig.Mode.AUTO -> FunctionCallingConfig.Internal.Mode.AUTO + FunctionCallingConfig.Mode.NONE -> FunctionCallingConfig.Internal.Mode.NONE }, it.allowedFunctionNames ) @@ -48,7 +43,7 @@ public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfi @Serializable internal data class Internal( - @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig.Internal? + @SerialName("function_calling_config") + val functionCallingConfig: FunctionCallingConfig.Internal? ) } - diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt index 232e8b27d11..54f5cbd89b7 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/UsageMetadata.kt @@ -42,4 +42,3 @@ public class UsageMetadata( UsageMetadata(promptTokenCount ?: 0, candidatesTokenCount ?: 0, totalTokenCount ?: 0) } } - diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt index 288ac4a85a7..67d41c9b5d6 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt @@ -130,15 +130,7 @@ internal class GenerativeModelTesting { private fun generateContentResponseAsJsonString(text: String): String { return JSON.encodeToString( GenerateContentResponse.Internal( - listOf( - Candidate.Internal( - Content.Internal( - parts = listOf( - TextPart.Internal(text) - ) - ) - ) - ) + listOf(Candidate.Internal(Content.Internal(parts = listOf(TextPart.Internal(text))))) ) ) } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt index a26006d44fb..463dbe773f7 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt @@ -186,12 +186,12 @@ internal class RequestFormatTests { model = "unused", contents = listOf(Content.Internal(parts = listOf(TextPart.Internal("Arbitrary")))), toolConfig = - ToolConfig.Internal( - FunctionCallingConfig.Internal( - mode = FunctionCallingConfig.Internal.Mode.ANY, - allowedFunctionNames = listOf("allowedFunctionName") + ToolConfig.Internal( + FunctionCallingConfig.Internal( + mode = FunctionCallingConfig.Internal.Mode.ANY, + allowedFunctionNames = listOf("allowedFunctionName") + ) ) - ) ), ) .collect { channel.close() } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt index c1ae7d422ce..2d29ad38ba7 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/StreamingSnapshotTests.kt @@ -91,7 +91,8 @@ internal class StreamingSnapshotTests { val responseList = responses.toList() responseList.isEmpty() shouldBe false - val part = responseList.first().candidates?.first()?.content?.parts?.first() as? TextPart.Internal + val part = + responseList.first().candidates?.first()?.content?.parts?.first() as? TextPart.Internal part.shouldNotBeNull() part.text shouldContain "\"" } @@ -170,7 +171,8 @@ internal class StreamingSnapshotTests { withTimeout(testTimeout) { val exception = shouldThrow { responses.collect() } - exception.response.candidates?.first()?.finishReason shouldBe FinishReason.Internal.RECITATION + exception.response.candidates?.first()?.finishReason shouldBe + FinishReason.Internal.RECITATION } } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt index ad85d01b544..33ebdda5322 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/UnarySnapshotTests.kt @@ -231,7 +231,12 @@ internal class UnarySnapshotTests { response.candidates?.isEmpty() shouldBe false with( - response.candidates?.first()?.content?.parts?.first()?.shouldBeInstanceOf() + response.candidates + ?.first() + ?.content + ?.parts + ?.first() + ?.shouldBeInstanceOf() ) { shouldNotBeNull() JSON.decodeFromString>(text).shouldNotBeEmpty() @@ -315,7 +320,8 @@ internal class UnarySnapshotTests { goldenUnaryFile("success-function-call-null.json") { withTimeout(testTimeout) { val response = apiController.generateContent(textGenerateContentRequest("prompt")) - val callPart = (response.candidates!!.first().content!!.parts.first() as FunctionCallPart.Internal) + val callPart = + (response.candidates!!.first().content!!.parts.first() as FunctionCallPart.Internal) callPart.functionCall.args shouldNotBe null callPart.functionCall.args?.get("season") shouldBe null diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt index 750b77f4e24..5e52b1827b0 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt @@ -43,8 +43,9 @@ import kotlinx.serialization.encodeToString private val TEST_CLIENT_ID = "genai-android/test" -internal fun prepareStreamingResponse(response: List): List = - response.map { "data: ${JSON.encodeToString(it)}$SSE_SEPARATOR".toByteArray() } +internal fun prepareStreamingResponse( + response: List +): List = response.map { "data: ${JSON.encodeToString(it)}$SSE_SEPARATOR".toByteArray() } internal fun prepareResponse(response: GenerateContentResponse.Internal) = JSON.encodeToString(response).toByteArray() @@ -60,23 +61,10 @@ internal fun createResponse(text: String) = createResponses(text).single() @OptIn(ExperimentalSerializationApi::class) internal fun createResponses(vararg text: String): List { - val candidates = text.map { - Candidate.Internal( - Content.Internal( - parts = listOf( - TextPart.Internal(it) - ) - ) - ) - } + val candidates = + text.map { Candidate.Internal(Content.Internal(parts = listOf(TextPart.Internal(it)))) } - return candidates.map { - GenerateContentResponse.Internal( - candidates = listOf( - it - ) - ) - } + return candidates.map { GenerateContentResponse.Internal(candidates = listOf(it)) } } /** From c75b68ece4a702ed7adf61d1b2c7a184969b5a23 Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Fri, 24 Jan 2025 12:05:49 -0600 Subject: [PATCH 4/9] Strip redundant serial names --- .../vertexai/type/FunctionCallingConfig.kt | 2 +- .../firebase/vertexai/type/GenerationConfig.kt | 18 +++++++++--------- .../firebase/vertexai/type/ToolConfig.kt | 1 - 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt index ee557556bbc..9461ac42b5e 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt @@ -51,7 +51,7 @@ internal constructor( @Serializable internal data class Internal( val mode: Mode, - @SerialName("allowed_function_names") val allowedFunctionNames: List? = null + val allowedFunctionNames: List? = null ) { @Serializable enum class Mode { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt index 4abec8a260d..0ffc76423cf 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt @@ -162,15 +162,15 @@ private constructor( @Serializable internal data class Internal( val temperature: Float?, - @SerialName("top_p") val topP: Float?, - @SerialName("top_k") val topK: Int?, - @SerialName("candidate_count") val candidateCount: Int?, - @SerialName("max_output_tokens") val maxOutputTokens: Int?, - @SerialName("stop_sequences") val stopSequences: List?, - @SerialName("response_mime_type") val responseMimeType: String? = null, - @SerialName("presence_penalty") val presencePenalty: Float? = null, - @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, - @SerialName("response_schema") val responseSchema: Schema.Internal? = null, + val topP: Float?, + val topK: Int?, + val candidateCount: Int?, + val maxOutputTokens: Int?, + val stopSequences: List?, + val responseMimeType: String? = null, + val presencePenalty: Float? = null, + val frequencyPenalty: Float? = null, + val responseSchema: Schema.Internal? = null, ) public companion object { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt index 99769ed46b6..b6b3d3d04b7 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt @@ -43,7 +43,6 @@ public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfi @Serializable internal data class Internal( - @SerialName("function_calling_config") val functionCallingConfig: FunctionCallingConfig.Internal? ) } From 143efecb4c623f8243c007da1cc0348e90b6071a Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Fri, 24 Jan 2025 14:28:03 -0600 Subject: [PATCH 5/9] Format --- .../google/firebase/vertexai/type/FunctionCallingConfig.kt | 5 +---- .../com/google/firebase/vertexai/type/GenerationConfig.kt | 1 - .../kotlin/com/google/firebase/vertexai/type/ToolConfig.kt | 5 +---- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt index 9461ac42b5e..147e90ee440 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt @@ -49,10 +49,7 @@ internal constructor( } @Serializable - internal data class Internal( - val mode: Mode, - val allowedFunctionNames: List? = null - ) { + internal data class Internal(val mode: Mode, val allowedFunctionNames: List? = null) { @Serializable enum class Mode { @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt index 0ffc76423cf..828b2b32447 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt @@ -16,7 +16,6 @@ package com.google.firebase.vertexai.type -import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable /** diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt index b6b3d3d04b7..479fd642a9d 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt @@ -16,7 +16,6 @@ package com.google.firebase.vertexai.type -import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable /** @@ -42,7 +41,5 @@ public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfi ) @Serializable - internal data class Internal( - val functionCallingConfig: FunctionCallingConfig.Internal? - ) + internal data class Internal(val functionCallingConfig: FunctionCallingConfig.Internal?) } From b748629971cac6889c7c665fe2a94dc67269edc3 Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Wed, 29 Jan 2025 10:50:17 -0600 Subject: [PATCH 6/9] Review changes --- firebase-vertexai/consumer-rules.pro | 1 + .../firebase/vertexai/GenerativeModel.kt | 2 - .../firebase/vertexai/common/Exceptions.kt | 15 +++++++ .../vertexai/internal/util/conversions.kt | 41 ------------------- .../firebase/vertexai/type/Exceptions.kt | 1 - .../firebase/vertexai/type/HarmBlockMethod.kt | 2 +- .../vertexai/type/HarmBlockThreshold.kt | 2 +- .../firebase/vertexai/type/HarmCategory.kt | 2 +- .../com/google/firebase/vertexai/type/Part.kt | 3 +- .../com/google/firebase/vertexai/type/Type.kt | 7 ++++ .../{InternalTests.kt => SchemaTests.kt} | 3 +- .../vertexai/common/EnumUpdateTests.kt | 1 - 12 files changed, 29 insertions(+), 51 deletions(-) delete mode 100644 firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt rename firebase-vertexai/src/test/java/com/google/firebase/vertexai/{InternalTests.kt => SchemaTests.kt} (98%) diff --git a/firebase-vertexai/consumer-rules.pro b/firebase-vertexai/consumer-rules.pro index 7947f53cd58..f328794a748 100644 --- a/firebase-vertexai/consumer-rules.pro +++ b/firebase-vertexai/consumer-rules.pro @@ -20,4 +20,5 @@ # hide the original source file name. #-renamesourcefileattribute SourceFile +-keep class com.google.firebase.vertexai.type.** { *; } -keep class com.google.firebase.vertexai.common.** { *; } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt index 45c6aa9bce4..c2b9fcfd2f9 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt @@ -24,8 +24,6 @@ import com.google.firebase.vertexai.common.APIController import com.google.firebase.vertexai.common.CountTokensRequest import com.google.firebase.vertexai.common.GenerateContentRequest import com.google.firebase.vertexai.common.HeaderProvider -import com.google.firebase.vertexai.internal.util.toInternal -import com.google.firebase.vertexai.internal.util.toPublic import com.google.firebase.vertexai.type.Content import com.google.firebase.vertexai.type.CountTokensResponse import com.google.firebase.vertexai.type.FinishReason diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt index fd38b9813d1..7567c384618 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt @@ -126,3 +126,18 @@ internal class ServiceDisabledException(message: String, cause: Throwable? = nul /** Catch all case for exceptions not explicitly expected. */ internal class UnknownException(message: String, cause: Throwable? = null) : FirebaseCommonAIException(message, cause) + +internal fun makeMissingCaseException( + source: String, + ordinal: Int +): com.google.firebase.vertexai.type.SerializationException { + return com.google.firebase.vertexai.type.SerializationException( + """ + |Missing case for a $source: $ordinal + |This error indicates that one of the `toInternal` conversions needs updating. + |If you're a developer seeing this exception, please file an issue on our GitHub repo: + |https://github.com/firebase/firebase-android-sdk + """ + .trimMargin() + ) +} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt deleted file mode 100644 index e24b7f08f4e..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2023 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.internal.util - -import android.util.Base64 -import com.google.firebase.vertexai.type.SerializationException -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonObject -import org.json.JSONObject - -internal const val BASE_64_FLAGS = Base64.NO_WRAP - -internal fun makeMissingCaseException(source: String, ordinal: Int): SerializationException { - return SerializationException( - """ - |Missing case for a $source: $ordinal - |This error indicates that one of the `toInternal` conversions needs updating. - |If you're a developer seeing this exception, please file an issue on our GitHub repo: - |https://github.com/firebase/firebase-android-sdk - """ - .trimMargin() - ) -} - -internal fun JSONObject.toInternal() = Json.decodeFromString(toString()) - -internal fun JsonObject.toPublic() = JSONObject(toString()) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt index a3bd95e15ab..4f4ca954f36 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt @@ -18,7 +18,6 @@ package com.google.firebase.vertexai.type import com.google.firebase.vertexai.FirebaseVertexAI import com.google.firebase.vertexai.common.FirebaseCommonAIException -import com.google.firebase.vertexai.internal.util.toPublic import kotlinx.coroutines.TimeoutCancellationException /** Parent class for any errors that occur from the [FirebaseVertexAI] SDK. */ diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt index d9acb7a28f6..1bd16949b20 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockMethod.kt @@ -16,7 +16,7 @@ package com.google.firebase.vertexai.type -import com.google.firebase.vertexai.internal.util.makeMissingCaseException +import com.google.firebase.vertexai.common.makeMissingCaseException import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt index 73e23be2341..1b3233bda2a 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmBlockThreshold.kt @@ -16,7 +16,7 @@ package com.google.firebase.vertexai.type -import com.google.firebase.vertexai.internal.util.makeMissingCaseException +import com.google.firebase.vertexai.common.makeMissingCaseException import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt index 857e674ed5f..2429688b02b 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/HarmCategory.kt @@ -16,8 +16,8 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.common.makeMissingCaseException import com.google.firebase.vertexai.common.util.FirstOrdinalSerializer -import com.google.firebase.vertexai.internal.util.makeMissingCaseException import kotlinx.serialization.KSerializer import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt index 000415d041d..cfffe6885b0 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Part.kt @@ -18,7 +18,6 @@ package com.google.firebase.vertexai.type import android.graphics.Bitmap import android.graphics.BitmapFactory -import com.google.firebase.vertexai.internal.util.BASE_64_FLAGS import java.io.ByteArrayOutputStream import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.SerialName @@ -133,6 +132,8 @@ public fun Part.asFileDataOrNull(): FileDataPart? = this as? FileDataPart internal typealias Base64 = String +internal const val BASE_64_FLAGS = android.util.Base64.NO_WRAP + @Serializable(PartSerializer::class) internal sealed interface InternalPart internal object PartSerializer : diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt index f58f289cdbb..a35000c7da5 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Type.kt @@ -17,6 +17,9 @@ package com.google.firebase.vertexai.type import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import org.json.JSONObject internal sealed interface Response @@ -38,3 +41,7 @@ internal data class GRpcErrorResponse(val error: GRpcError) : Response { ) } } + +internal fun JSONObject.toInternal() = Json.decodeFromString(toString()) + +internal fun JsonObject.toPublic() = JSONObject(toString()) diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SchemaTests.kt similarity index 98% rename from firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalTests.kt rename to firebase-vertexai/src/test/java/com/google/firebase/vertexai/SchemaTests.kt index cc23a9af859..4701d516ff5 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/InternalTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/SchemaTests.kt @@ -16,7 +16,6 @@ package com.google.firebase.vertexai -import com.google.firebase.vertexai.internal.util.toInternal import com.google.firebase.vertexai.type.Schema import com.google.firebase.vertexai.type.StringFormat import io.kotest.assertions.json.shouldEqualJson @@ -24,7 +23,7 @@ import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import org.junit.Test -internal class InternalTests { +internal class SchemaTests { @Test fun `basic schema declaration`() { val schemaDeclaration = diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/EnumUpdateTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/EnumUpdateTests.kt index ddee4dbabf1..769adbd4cd8 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/EnumUpdateTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/EnumUpdateTests.kt @@ -16,7 +16,6 @@ package com.google.firebase.vertexai.common -import com.google.firebase.vertexai.internal.util.toInternal import com.google.firebase.vertexai.type.HarmBlockMethod import com.google.firebase.vertexai.type.HarmBlockThreshold import com.google.firebase.vertexai.type.HarmCategory From 1c5d1a723e3771af9b41417224aac7a8609252d9 Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Wed, 29 Jan 2025 10:52:41 -0600 Subject: [PATCH 7/9] Revert "Strip redundant serial names" This reverts commit c75b68ece4a702ed7adf61d1b2c7a184969b5a23. --- .../vertexai/type/FunctionCallingConfig.kt | 5 ++++- .../firebase/vertexai/type/GenerationConfig.kt | 18 +++++++++--------- .../firebase/vertexai/type/ToolConfig.kt | 6 +++++- 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt index 147e90ee440..ee557556bbc 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/FunctionCallingConfig.kt @@ -49,7 +49,10 @@ internal constructor( } @Serializable - internal data class Internal(val mode: Mode, val allowedFunctionNames: List? = null) { + internal data class Internal( + val mode: Mode, + @SerialName("allowed_function_names") val allowedFunctionNames: List? = null + ) { @Serializable enum class Mode { @SerialName("MODE_UNSPECIFIED") UNSPECIFIED, diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt index 828b2b32447..af27c710fd9 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt @@ -161,15 +161,15 @@ private constructor( @Serializable internal data class Internal( val temperature: Float?, - val topP: Float?, - val topK: Int?, - val candidateCount: Int?, - val maxOutputTokens: Int?, - val stopSequences: List?, - val responseMimeType: String? = null, - val presencePenalty: Float? = null, - val frequencyPenalty: Float? = null, - val responseSchema: Schema.Internal? = null, + @SerialName("top_p") val topP: Float?, + @SerialName("top_k") val topK: Int?, + @SerialName("candidate_count") val candidateCount: Int?, + @SerialName("max_output_tokens") val maxOutputTokens: Int?, + @SerialName("stop_sequences") val stopSequences: List?, + @SerialName("response_mime_type") val responseMimeType: String? = null, + @SerialName("presence_penalty") val presencePenalty: Float? = null, + @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, + @SerialName("response_schema") val responseSchema: Schema.Internal? = null, ) public companion object { diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt index 479fd642a9d..99769ed46b6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ToolConfig.kt @@ -16,6 +16,7 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable /** @@ -41,5 +42,8 @@ public class ToolConfig(internal val functionCallingConfig: FunctionCallingConfi ) @Serializable - internal data class Internal(val functionCallingConfig: FunctionCallingConfig.Internal?) + internal data class Internal( + @SerialName("function_calling_config") + val functionCallingConfig: FunctionCallingConfig.Internal? + ) } From f772112bcd76691b1d3b9421b88f45d65ad1eecd Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Wed, 29 Jan 2025 13:56:22 -0600 Subject: [PATCH 8/9] Merge --- .../com/google/firebase/vertexai/type/Content.kt | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt index 6d2285a281d..241d0becfe6 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Content.kt @@ -88,8 +88,14 @@ constructor(public val role: String? = "user", public val parts: List) { @EncodeDefault val role: String? = "user", val parts: List ) { - - internal fun toPublic(): Content = Content(role, parts.map { it.toPublic() }) + internal fun toPublic(): Content { + val returnedParts = + parts.map { it.toPublic() }.filterNot { it is TextPart && it.text.isEmpty() } + // If all returned parts were text and empty, we coalesce them into a single one-character + // string + // part so the backend doesn't fail if we send this back as part of a multi-turn interaction. + return Content(role, returnedParts.ifEmpty { listOf(TextPart(" ")) }) + } } } From f618e009228b6e3f9f1e8aa606bcc0b42be30bc2 Mon Sep 17 00:00:00 2001 From: Emily Ploszaj Date: Wed, 29 Jan 2025 14:02:32 -0600 Subject: [PATCH 9/9] Resolve compilation --- .../kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt | 1 + 1 file changed, 1 insertion(+) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt index af27c710fd9..4abec8a260d 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt @@ -16,6 +16,7 @@ package com.google.firebase.vertexai.type +import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable /**