Skip to content

Commit 0e6ba3e

Browse files
feat: Add responseModalities to GenerationConfig
Adds an optional `responseModalities` property to the `GenerationConfig` object in the Vertex AI SDK. This allows specifying the desired modalities (TEXT, IMAGE, AUDIO) for the response. Includes updates to the `GenerationConfig.Builder` and internal serialization logic. Adds unit tests to verify: - `response_modalities` field is omitted when null. - Correct serialization for all modalities ([TEXT, IMAGE, AUDIO]). - Correct serialization for a subset of modalities ([TEXT, IMAGE]).
1 parent 51b4a1c commit 0e6ba3e

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/GenerationConfig.kt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.firebase.vertexai.type
1818

19+
import com.google.firebase.vertexai.type.ResponseModality // Added import
1920
import kotlinx.serialization.SerialName
2021
import kotlinx.serialization.Serializable
2122

@@ -88,6 +89,7 @@ private constructor(
8889
internal val stopSequences: List<String>?,
8990
internal val responseMimeType: String?,
9091
internal val responseSchema: Schema?,
92+
internal val responseModalities: List<ResponseModality>?, // Added property
9193
) {
9294

9395
/**
@@ -128,6 +130,7 @@ private constructor(
128130
@JvmField public var stopSequences: List<String>? = null
129131
@JvmField public var responseMimeType: String? = null
130132
@JvmField public var responseSchema: Schema? = null
133+
@JvmField public var responseModalities: List<ResponseModality>? = null // Added property
131134

132135
/** Create a new [GenerationConfig] with the attached arguments. */
133136
public fun build(): GenerationConfig =
@@ -142,6 +145,7 @@ private constructor(
142145
frequencyPenalty = frequencyPenalty,
143146
responseMimeType = responseMimeType,
144147
responseSchema = responseSchema,
148+
responseModalities = responseModalities, // Added property
145149
)
146150
}
147151

@@ -156,7 +160,9 @@ private constructor(
156160
frequencyPenalty = frequencyPenalty,
157161
presencePenalty = presencePenalty,
158162
responseMimeType = responseMimeType,
159-
responseSchema = responseSchema?.toInternal()
163+
responseSchema = responseSchema?.toInternal(),
164+
// Pass the responseModalities to the Internal class constructor
165+
responseModalities = this.responseModalities
160166
)
161167

162168
@Serializable
@@ -171,6 +177,7 @@ private constructor(
171177
@SerialName("presence_penalty") val presencePenalty: Float? = null,
172178
@SerialName("frequency_penalty") val frequencyPenalty: Float? = null,
173179
@SerialName("response_schema") val responseSchema: Schema.Internal? = null,
180+
@SerialName("response_modalities") val responseModalities: List<ResponseModality>? = null, // Added property
174181
)
175182

176183
public companion object {
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.vertexai.type
18+
19+
import kotlin.test.Test
20+
import kotlin.test.assertEquals
21+
import kotlin.test.assertFalse
22+
import kotlinx.serialization.encodeToString
23+
import kotlinx.serialization.json.Json
24+
import kotlinx.serialization.json.jsonObject
25+
import kotlinx.serialization.json.jsonPrimitive
26+
27+
class GenerationConfigTests {
28+
29+
@Test
30+
fun `serialization omits null responseModalities`() {
31+
val config = generationConfig { /* No responseModalities set */ }
32+
val internalConfig = config.toInternal()
33+
val jsonString = Json.encodeToString(internalConfig)
34+
val jsonObject = Json.parseToJsonElement(jsonString).jsonObject
35+
36+
assertFalse(jsonObject.containsKey("response_modalities"))
37+
}
38+
39+
@Test
40+
fun `serialization includes all responseModalities`() {
41+
val config = generationConfig {
42+
responseModalities = listOf(ResponseModality.TEXT, ResponseModality.IMAGE, ResponseModality.AUDIO)
43+
}
44+
val internalConfig = config.toInternal()
45+
val jsonString = Json.encodeToString(internalConfig)
46+
val jsonObject = Json.parseToJsonElement(jsonString).jsonObject
47+
48+
// Assert the JSON output contains the correct array for response_modalities
49+
assertEquals(
50+
"""["TEXT","IMAGE","AUDIO"]""",
51+
jsonObject["response_modalities"].toString()
52+
)
53+
}
54+
55+
@Test
56+
fun `serialization includes some responseModalities`() {
57+
val config = generationConfig {
58+
responseModalities = listOf(ResponseModality.TEXT, ResponseModality.IMAGE)
59+
}
60+
val internalConfig = config.toInternal()
61+
val jsonString = Json.encodeToString(internalConfig)
62+
val jsonObject = Json.parseToJsonElement(jsonString).jsonObject
63+
64+
// Assert the JSON output contains the correct array for response_modalities
65+
assertEquals(
66+
"""["TEXT","IMAGE"]""",
67+
jsonObject["response_modalities"].toString()
68+
)
69+
}
70+
71+
// TODO: Add tests for other properties as well
72+
@Test
73+
fun `serialization includes temperature`() {
74+
val temp = 0.8f
75+
val config = generationConfig { temperature = temp }
76+
val internalConfig = config.toInternal()
77+
val jsonString = Json.encodeToString(internalConfig)
78+
val jsonObject = Json.parseToJsonElement(jsonString).jsonObject
79+
80+
assertEquals(temp, jsonObject["temperature"]?.jsonPrimitive?.content?.toFloat())
81+
}
82+
}

0 commit comments

Comments
 (0)