diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt index 4abec8a260d..96685527f2d 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt @@ -16,6 +16,7 @@ package com.google.firebase.vertexai.type +import com.google.firebase.vertexai.type.ResponseModality // Added import import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -88,6 +89,7 @@ private constructor( internal val stopSequences: List?, internal val responseMimeType: String?, internal val responseSchema: Schema?, + internal val responseModalities: List?, // Added property ) { /** @@ -128,6 +130,7 @@ private constructor( @JvmField public var stopSequences: List? = null @JvmField public var responseMimeType: String? = null @JvmField public var responseSchema: Schema? = null + @JvmField public var responseModalities: List? = null // Added property /** Create a new [GenerationConfig] with the attached arguments. */ public fun build(): GenerationConfig = @@ -142,6 +145,7 @@ private constructor( frequencyPenalty = frequencyPenalty, responseMimeType = responseMimeType, responseSchema = responseSchema, + responseModalities = responseModalities, // Added property ) } @@ -156,7 +160,9 @@ private constructor( frequencyPenalty = frequencyPenalty, presencePenalty = presencePenalty, responseMimeType = responseMimeType, - responseSchema = responseSchema?.toInternal() + responseSchema = responseSchema?.toInternal(), + // Pass the responseModalities to the Internal class constructor + responseModalities = this.responseModalities ) @Serializable @@ -171,6 +177,7 @@ private constructor( @SerialName("presence_penalty") val presencePenalty: Float? = null, @SerialName("frequency_penalty") val frequencyPenalty: Float? = null, @SerialName("response_schema") val responseSchema: Schema.Internal? = null, + @SerialName("response_modalities") val responseModalities: List? = null, // Added property ) public companion object { diff --git a/firebase-vertexai/src/test/kotlin/com/google/firebase/vertexai/type/GenerationConfigTests.kt b/firebase-vertexai/src/test/kotlin/com/google/firebase/vertexai/type/GenerationConfigTests.kt new file mode 100644 index 00000000000..4f36841fa85 --- /dev/null +++ b/firebase-vertexai/src/test/kotlin/com/google/firebase/vertexai/type/GenerationConfigTests.kt @@ -0,0 +1,82 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.firebase.vertexai.type + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive + +class GenerationConfigTests { + + @Test + fun `serialization omits null responseModalities`() { + val config = generationConfig { /* No responseModalities set */ } + val internalConfig = config.toInternal() + val jsonString = Json.encodeToString(internalConfig) + val jsonObject = Json.parseToJsonElement(jsonString).jsonObject + + assertFalse(jsonObject.containsKey("response_modalities")) + } + + @Test + fun `serialization includes all responseModalities`() { + val config = generationConfig { + responseModalities = listOf(ResponseModality.TEXT, ResponseModality.IMAGE, ResponseModality.AUDIO) + } + val internalConfig = config.toInternal() + val jsonString = Json.encodeToString(internalConfig) + val jsonObject = Json.parseToJsonElement(jsonString).jsonObject + + // Assert the JSON output contains the correct array for response_modalities + assertEquals( + """["TEXT","IMAGE","AUDIO"]""", + jsonObject["response_modalities"].toString() + ) + } + + @Test + fun `serialization includes some responseModalities`() { + val config = generationConfig { + responseModalities = listOf(ResponseModality.TEXT, ResponseModality.IMAGE) + } + val internalConfig = config.toInternal() + val jsonString = Json.encodeToString(internalConfig) + val jsonObject = Json.parseToJsonElement(jsonString).jsonObject + + // Assert the JSON output contains the correct array for response_modalities + assertEquals( + """["TEXT","IMAGE"]""", + jsonObject["response_modalities"].toString() + ) + } + + // TODO: Add tests for other properties as well + @Test + fun `serialization includes temperature`() { + val temp = 0.8f + val config = generationConfig { temperature = temp } + val internalConfig = config.toInternal() + val jsonString = Json.encodeToString(internalConfig) + val jsonObject = Json.parseToJsonElement(jsonString).jsonObject + + assertEquals(temp, jsonObject["temperature"]?.jsonPrimitive?.content?.toFloat()) + } +}