Skip to content
This repository was archived by the owner on Dec 16, 2025. It is now read-only.

Commit 87fffa9

Browse files
rlazodaymxn
andauthored
Merge dev branch (#57)
Co-authored-by: Daymon <17409137+daymxn@users.noreply.github.com>
1 parent dc21ffd commit 87fffa9

File tree

9 files changed

+178
-27
lines changed

9 files changed

+178
-27
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"type":"MINOR","changes":["Add RequestOptions; configuration points for backend implementation details such as api version and timeout."]}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"type":"MAJOR","changes":["Support a general model naming schema"]}

generativeai/build.gradle.kts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ dependencies {
8383

8484
implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.5.1")
8585
implementation("androidx.core:core-ktx:1.12.0")
86-
implementation("org.slf4j:slf4j-android:1.7.36")
86+
implementation("org.slf4j:slf4j-nop:2.0.9")
8787
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-android:1.7.3")
8888
implementation("org.jetbrains.kotlinx:kotlinx-coroutines-reactive:1.7.3")
8989
implementation("org.reactivestreams:reactive-streams:1.0.3")
@@ -94,6 +94,7 @@ dependencies {
9494
testImplementation("junit:junit:4.13.2")
9595
testImplementation("io.kotest:kotest-assertions-core:4.0.7")
9696
testImplementation("io.kotest:kotest-assertions-jvm:4.0.7")
97+
testImplementation("io.kotest:kotest-assertions-json:4.0.7")
9798
testImplementation("io.ktor:ktor-client-mock:$ktorVersion")
9899
androidTestImplementation("androidx.test.ext:junit:1.1.5")
99100
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")

generativeai/src/main/java/com/google/ai/client/generativeai/GenerativeModel.kt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import com.google.ai.client.generativeai.type.GenerateContentResponse
2929
import com.google.ai.client.generativeai.type.GenerationConfig
3030
import com.google.ai.client.generativeai.type.GoogleGenerativeAIException
3131
import com.google.ai.client.generativeai.type.PromptBlockedException
32+
import com.google.ai.client.generativeai.type.RequestOptions
3233
import com.google.ai.client.generativeai.type.ResponseStoppedException
3334
import com.google.ai.client.generativeai.type.SafetySetting
3435
import com.google.ai.client.generativeai.type.SerializationException
@@ -45,13 +46,15 @@ import kotlinx.coroutines.flow.map
4546
* @property generationConfig configuration parameters to use for content generation
4647
* @property safetySettings the safety bounds to use during alongside prompts during content
4748
* generation
49+
* @property requestOptions configuration options to utilize during backend communication
4850
*/
4951
class GenerativeModel
5052
internal constructor(
5153
val modelName: String,
5254
val apiKey: String,
5355
val generationConfig: GenerationConfig? = null,
5456
val safetySettings: List<SafetySetting>? = null,
57+
val requestOptions: RequestOptions = RequestOptions(),
5558
private val controller: APIController
5659
) {
5760

@@ -61,7 +64,15 @@ internal constructor(
6164
apiKey: String,
6265
generationConfig: GenerationConfig? = null,
6366
safetySettings: List<SafetySetting>? = null,
64-
) : this(modelName, apiKey, generationConfig, safetySettings, APIController(apiKey, modelName))
67+
requestOptions: RequestOptions = RequestOptions(),
68+
) : this(
69+
modelName,
70+
apiKey,
71+
generationConfig,
72+
safetySettings,
73+
requestOptions,
74+
APIController(apiKey, modelName, requestOptions.apiVersion, requestOptions.timeout)
75+
)
6576

6677
/**
6778
* Generates a response from the backend with the provided [Content]s.

generativeai/src/main/java/com/google/ai/client/generativeai/internal/api/APIController.kt

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,15 @@ import io.ktor.http.ContentType
3737
import io.ktor.http.HttpStatusCode
3838
import io.ktor.http.contentType
3939
import io.ktor.serialization.kotlinx.json.json
40+
import kotlin.time.Duration
4041
import kotlinx.coroutines.CoroutineName
4142
import kotlinx.coroutines.flow.Flow
4243
import kotlinx.coroutines.flow.channelFlow
44+
import kotlinx.coroutines.flow.timeout
4345
import kotlinx.coroutines.launch
4446
import kotlinx.serialization.json.Json
4547

46-
// TODO: Should these stay here or be moved elsewhere?
47-
internal const val DOMAIN = "https://generativelanguage.googleapis.com/v1"
48+
internal const val DOMAIN = "https://generativelanguage.googleapis.com"
4849

4950
internal val JSON = Json {
5051
ignoreUnknownKeys = true
@@ -60,42 +61,46 @@ internal val JSON = Json {
6061
* Exposed primarily for DI in tests.
6162
* @property key The API key used for authentication.
6263
* @property model The model to use for generation.
64+
* @property apiVersion the endpoint version to communicate with.
65+
* @property timeout the maximum amount of time for a request to take in the initial exchange.
6366
*/
6467
internal class APIController(
6568
private val key: String,
6669
model: String,
67-
httpEngine: HttpClientEngine = OkHttp.create()
70+
private val apiVersion: String,
71+
private val timeout: Duration,
72+
httpEngine: HttpClientEngine = OkHttp.create(),
6873
) {
6974
private val model = fullModelName(model)
7075

7176
private val client =
7277
HttpClient(httpEngine) {
7378
install(HttpTimeout) {
74-
requestTimeoutMillis = HttpTimeout.INFINITE_TIMEOUT_MS
79+
requestTimeoutMillis = timeout.inWholeMilliseconds
7580
socketTimeoutMillis = 80_000
7681
}
7782
install(ContentNegotiation) { json(JSON) }
7883
}
7984

80-
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse {
81-
return client
82-
.post("$DOMAIN/$model:generateContent") { applyCommonConfiguration(request) }
85+
suspend fun generateContent(request: GenerateContentRequest): GenerateContentResponse =
86+
client
87+
.post("$DOMAIN/$apiVersion/$model:generateContent") { applyCommonConfiguration(request) }
8388
.also { validateResponse(it) }
8489
.body()
85-
}
8690

8791
fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> {
88-
return client.postStream("$DOMAIN/$model:streamGenerateContent?alt=sse") {
92+
return client.postStream<GenerateContentResponse>(
93+
"$DOMAIN/$apiVersion/$model:streamGenerateContent?alt=sse"
94+
) {
8995
applyCommonConfiguration(request)
9096
}
9197
}
9298

93-
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse {
94-
return client
95-
.post("$DOMAIN/$model:countTokens") { applyCommonConfiguration(request) }
99+
suspend fun countTokens(request: CountTokensRequest): CountTokensResponse =
100+
client
101+
.post("$DOMAIN/$apiVersion/$model:countTokens") { applyCommonConfiguration(request) }
96102
.also { validateResponse(it) }
97103
.body()
98-
}
99104

100105
private fun HttpRequestBuilder.applyCommonConfiguration(request: Request) {
101106
when (request) {
@@ -113,8 +118,7 @@ internal class APIController(
113118
*
114119
* Models must be prepended with the `models/` prefix when communicating with the backend.
115120
*/
116-
private fun fullModelName(name: String): String =
117-
name.takeIf { it.startsWith("models/") } ?: "models/$name"
121+
private fun fullModelName(name: String): String = name.takeIf { it.contains("/") } ?: "models/$name"
118122

119123
/**
120124
* Makes a POST request to the specified [url] and returns a [Flow] of deserialized response objects

generativeai/src/main/java/com/google/ai/client/generativeai/type/Exceptions.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package com.google.ai.client.generativeai.type
1818

1919
import com.google.ai.client.generativeai.GenerativeModel
2020
import io.ktor.serialization.JsonConvertException
21+
import kotlinx.coroutines.TimeoutCancellationException
2122

2223
/** Parent class for any errors that occur from [GenerativeModel]. */
2324
sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = null) :
@@ -39,6 +40,8 @@ sealed class GoogleGenerativeAIException(message: String, cause: Throwable? = nu
3940
"Something went wrong while trying to deserialize a response from the server.",
4041
cause
4142
)
43+
is TimeoutCancellationException ->
44+
RequestTimeoutException("The request failed to complete in the allotted time.")
4245
else -> UnknownException("Something unexpected happened.", cause)
4346
}
4447
}
@@ -84,6 +87,14 @@ class ResponseStoppedException(val response: GenerateContentResponse, cause: Thr
8487
cause
8588
)
8689

90+
/**
91+
* A request took too long to complete.
92+
*
93+
* Usually occurs due to a user specified [timeout][RequestOptions.timeout].
94+
*/
95+
class RequestTimeoutException(message: String, cause: Throwable? = null) :
96+
GoogleGenerativeAIException(message, cause)
97+
8798
/** Catch all case for exceptions not explicitly expected. */
8899
class UnknownException(message: String, cause: Throwable? = null) :
89100
GoogleGenerativeAIException(message, cause)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.ai.client.generativeai.type
18+
19+
import io.ktor.client.plugins.HttpTimeout
20+
import kotlin.time.Duration
21+
import kotlin.time.DurationUnit
22+
import kotlin.time.toDuration
23+
24+
/**
25+
* Configurable options unique to how requests to the backend are performed.
26+
*
27+
* @property timeout the maximum amount of time for a request to take, from the first request to
28+
* first response.
29+
* @property apiVersion the api endpoint to call.
30+
*/
31+
class RequestOptions(val timeout: Duration, val apiVersion: String = "v1") {
32+
@JvmOverloads
33+
constructor(
34+
timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS,
35+
apiVersion: String = "v1"
36+
) : this(
37+
(timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS),
38+
apiVersion
39+
)
40+
}

generativeai/src/test/java/com/google/ai/client/generativeai/GenerativeModelTests.kt

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,29 @@
1616

1717
package com.google.ai.client.generativeai
1818

19+
import com.google.ai.client.generativeai.type.RequestOptions
20+
import com.google.ai.client.generativeai.type.RequestTimeoutException
1921
import com.google.ai.client.generativeai.util.commonTest
22+
import com.google.ai.client.generativeai.util.createGenerativeModel
2023
import com.google.ai.client.generativeai.util.createResponses
24+
import com.google.ai.client.generativeai.util.doBlocking
2125
import com.google.ai.client.generativeai.util.prepareStreamingResponse
26+
import io.kotest.assertions.throwables.shouldThrow
2227
import io.kotest.matchers.shouldBe
28+
import io.kotest.matchers.string.shouldContain
29+
import io.ktor.client.engine.mock.MockEngine
30+
import io.ktor.client.engine.mock.respond
31+
import io.ktor.http.HttpHeaders
32+
import io.ktor.http.HttpStatusCode
33+
import io.ktor.http.headersOf
34+
import io.ktor.utils.io.ByteChannel
2335
import io.ktor.utils.io.close
2436
import io.ktor.utils.io.writeFully
2537
import kotlin.time.Duration.Companion.seconds
26-
import kotlinx.coroutines.flow.collect
2738
import kotlinx.coroutines.withTimeout
2839
import org.junit.Test
40+
import org.junit.runner.RunWith
41+
import org.junit.runners.Parameterized
2942

3043
internal class GenerativeModelTests {
3144
private val testTimeout = 5.seconds
@@ -45,4 +58,49 @@ internal class GenerativeModelTests {
4558
}
4659
}
4760
}
61+
62+
@Test
63+
fun `(generateContent) respects a custom timeout`() =
64+
commonTest(requestOptions = RequestOptions(2.seconds)) {
65+
shouldThrow<RequestTimeoutException> {
66+
withTimeout(testTimeout) { model.generateContent("d") }
67+
}
68+
}
69+
}
70+
71+
@RunWith(Parameterized::class)
72+
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {
73+
74+
@Test
75+
fun `request should include right model name`() = doBlocking {
76+
val channel = ByteChannel(autoFlush = true)
77+
val mockEngine = MockEngine {
78+
respond(channel, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
79+
}
80+
prepareStreamingResponse(createResponses("Random")).forEach { channel.writeFully(it) }
81+
val model =
82+
createGenerativeModel(modelName, "super_cool_test_key", RequestOptions(), mockEngine)
83+
84+
withTimeout(5.seconds) {
85+
model.generateContentStream().collect {
86+
it.candidates.isEmpty() shouldBe false
87+
channel.close()
88+
}
89+
}
90+
91+
mockEngine.requestHistory.first().url.encodedPath shouldContain actualName
92+
}
93+
94+
companion object {
95+
@JvmStatic
96+
@Parameterized.Parameters
97+
fun data() =
98+
listOf(
99+
arrayOf("gemini-pro", "models/gemini-pro"),
100+
arrayOf("x/gemini-pro", "x/gemini-pro"),
101+
arrayOf("models/gemini-pro", "models/gemini-pro"),
102+
arrayOf("/modelname", "/modelname"),
103+
arrayOf("modifiedNaming/mymodel", "modifiedNaming/mymodel"),
104+
)
105+
}
48106
}

generativeai/src/test/java/com/google/ai/client/generativeai/util/tests.kt

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import com.google.ai.client.generativeai.internal.api.shared.Content
2828
import com.google.ai.client.generativeai.internal.api.shared.TextPart
2929
import com.google.ai.client.generativeai.internal.util.SSE_SEPARATOR
3030
import com.google.ai.client.generativeai.internal.util.send
31+
import com.google.ai.client.generativeai.type.RequestOptions
3132
import io.ktor.client.engine.mock.MockEngine
3233
import io.ktor.client.engine.mock.respond
3334
import io.ktor.http.HttpHeaders
@@ -93,19 +94,42 @@ internal typealias CommonTest = suspend CommonTestScope.() -> Unit
9394
* ```
9495
*
9596
* @param status An optional [HttpStatusCode] to return as a response
97+
* @param requestOptions Optional [RequestOptions] to utilize in the underlying controller
9698
* @param block The test contents themselves, with the [CommonTestScope] implicitly provided
9799
* @see CommonTestScope
98100
*/
99-
internal fun commonTest(status: HttpStatusCode = HttpStatusCode.OK, block: CommonTest) =
100-
doBlocking {
101-
val channel = ByteChannel(autoFlush = true)
102-
val mockEngine = MockEngine {
103-
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
104-
}
105-
val controller = APIController("super_cool_test_key", "gemini-pro", mockEngine)
106-
val model = GenerativeModel("gemini-pro", "super_cool_test_key", controller = controller)
107-
CommonTestScope(channel, model).block()
101+
internal fun commonTest(
102+
status: HttpStatusCode = HttpStatusCode.OK,
103+
requestOptions: RequestOptions = RequestOptions(),
104+
block: CommonTest
105+
) = doBlocking {
106+
val channel = ByteChannel(autoFlush = true)
107+
val mockEngine = MockEngine {
108+
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
108109
}
110+
val model = createGenerativeModel("gemini-pro", "super_cool_test_key", requestOptions, mockEngine)
111+
CommonTestScope(channel, model).block()
112+
}
113+
114+
/** Simple wrapper that guarantees the model and APIController are created using the same data */
115+
internal fun createGenerativeModel(
116+
name: String,
117+
apikey: String,
118+
requestOptions: RequestOptions = RequestOptions(),
119+
engine: MockEngine
120+
) =
121+
GenerativeModel(
122+
name,
123+
apikey,
124+
controller =
125+
APIController(
126+
"super_cool_test_key",
127+
name,
128+
requestOptions.apiVersion,
129+
requestOptions.timeout,
130+
engine
131+
)
132+
)
109133

110134
/**
111135
* A variant of [commonTest] for performing *streaming-based* snapshot tests.

0 commit comments

Comments
 (0)