diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt index 96414185742..311d333e2f2 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt @@ -83,7 +83,7 @@ internal constructor( APIController( apiKey, modelName, - requestOptions.toInternal(), + requestOptions, "gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}", object : HeaderProvider { override val timeout: Duration diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt index 113573bab1e..f81fa0fc99a 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt @@ -21,6 +21,7 @@ import androidx.annotation.VisibleForTesting import com.google.firebase.vertexai.common.server.FinishReason import com.google.firebase.vertexai.common.util.decodeToFlow import com.google.firebase.vertexai.common.util.fullModelName +import com.google.firebase.vertexai.type.RequestOptions import io.ktor.client.HttpClient import io.ktor.client.call.body import io.ktor.client.engine.HttpClientEngine @@ -44,7 +45,9 @@ import io.ktor.http.contentType import io.ktor.http.headersOf import io.ktor.serialization.kotlinx.json.json import io.ktor.utils.io.ByteChannel +import kotlin.math.max import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.flow.Flow @@ -115,7 +118,8 @@ internal constructor( HttpClient(httpEngine) { install(HttpTimeout) { requestTimeoutMillis = requestOptions.timeout.inWholeMilliseconds - socketTimeoutMillis = 80_000 + socketTimeoutMillis = + max(180.seconds.inWholeMilliseconds, requestOptions.timeout.inWholeMilliseconds) } install(ContentNegotiation) { json(JSON) } } diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/RequestOptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/RequestOptions.kt deleted file mode 100644 index 658c0f65836..00000000000 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/RequestOptions.kt +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.firebase.vertexai.common - -import io.ktor.client.plugins.HttpTimeout -import kotlin.time.Duration -import kotlin.time.DurationUnit -import kotlin.time.toDuration - -/** - * Configurable options unique to how requests to the backend are performed. - * - * @property timeout the maximum amount of time for a request to take, from the first request to - * first response. - * @property apiVersion the api endpoint to call. - */ -internal class RequestOptions( - val timeout: Duration, - val apiVersion: String = "v1beta", - val endpoint: String = "https://generativelanguage.googleapis.com", -) { - @JvmOverloads - constructor( - timeout: Long? = HttpTimeout.INFINITE_TIMEOUT_MS, - apiVersion: String = "v1beta", - endpoint: String = "https://generativelanguage.googleapis.com", - ) : this( - (timeout ?: HttpTimeout.INFINITE_TIMEOUT_MS).toDuration(DurationUnit.MILLISECONDS), - apiVersion, - endpoint, - ) -} diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt index 2710501dded..3adf069a0ed 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt @@ -47,7 +47,6 @@ import com.google.firebase.vertexai.type.HarmSeverity import com.google.firebase.vertexai.type.ImagePart import com.google.firebase.vertexai.type.Part import com.google.firebase.vertexai.type.PromptFeedback -import com.google.firebase.vertexai.type.RequestOptions import com.google.firebase.vertexai.type.SafetyRating import com.google.firebase.vertexai.type.SafetySetting import com.google.firebase.vertexai.type.SerializationException @@ -63,9 +62,6 @@ import org.json.JSONObject private const val BASE_64_FLAGS = Base64.NO_WRAP -internal fun RequestOptions.toInternal() = - com.google.firebase.vertexai.common.RequestOptions(timeout, apiVersion, endpoint) - internal fun Content.toInternal() = com.google.firebase.vertexai.common.shared.Content( this.role ?: "user", diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/RequestOptions.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/RequestOptions.kt index 5a7d11f2cf8..0fedcdb5c4f 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/RequestOptions.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/RequestOptions.kt @@ -17,23 +17,26 @@ package com.google.firebase.vertexai.type import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds import kotlin.time.DurationUnit import kotlin.time.toDuration -/** - * Configurable options unique to how requests to the backend are performed. - * - * @property timeout the maximum amount of time for a request to take, from the first request to - * first response. - * @property apiVersion the api endpoint to call. - */ -class RequestOptions(val timeout: Duration) { - - internal val endpoint = "https://firebaseml.googleapis.com" - internal val apiVersion = "v2beta" +/** Configurable options unique to how requests to the backend are performed. */ +class RequestOptions +internal constructor( + internal val timeout: Duration, + internal val endpoint: String = "https://firebaseml.googleapis.com", + internal val apiVersion: String = "v2beta", +) { + /** + * Constructor for RequestOptions. + * + * @param timeoutInMillis the maximum amount of time, in milliseconds, for a request to take, from + * the first request to first response. + */ @JvmOverloads constructor( - timeout: Long? = Long.MAX_VALUE, - ) : this((timeout ?: Long.MAX_VALUE).toDuration(DurationUnit.MILLISECONDS)) + timeoutInMillis: Long = 180.seconds.inWholeMilliseconds + ) : this(timeout = timeoutInMillis.toDuration(DurationUnit.MILLISECONDS)) } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt index b683c1ba742..582678cf306 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt @@ -26,6 +26,7 @@ import com.google.firebase.vertexai.common.util.commonTest import com.google.firebase.vertexai.common.util.createResponses import com.google.firebase.vertexai.common.util.doBlocking import com.google.firebase.vertexai.common.util.prepareStreamingResponse +import com.google.firebase.vertexai.type.RequestOptions import io.kotest.assertions.json.shouldContainJsonKey import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe @@ -107,7 +108,7 @@ internal class RequestFormatTests { } } - mockEngine.requestHistory.first().url.host shouldBe "generativelanguage.googleapis.com" + mockEngine.requestHistory.first().url.host shouldBe "firebaseml.googleapis.com" } @Test @@ -121,7 +122,7 @@ internal class RequestFormatTests { APIController( "super_cool_test_key", "gemini-pro-1.5", - RequestOptions(endpoint = "https://my.custom.endpoint"), + RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"), mockEngine, TEST_CLIENT_ID, null, diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt index dba31e45730..8a7184e9851 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt @@ -22,10 +22,10 @@ import com.google.firebase.vertexai.common.APIController import com.google.firebase.vertexai.common.GenerateContentRequest import com.google.firebase.vertexai.common.GenerateContentResponse import com.google.firebase.vertexai.common.JSON -import com.google.firebase.vertexai.common.RequestOptions import com.google.firebase.vertexai.common.server.Candidate import com.google.firebase.vertexai.common.shared.Content import com.google.firebase.vertexai.common.shared.TextPart +import com.google.firebase.vertexai.type.RequestOptions import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.nulls.shouldNotBeNull import io.ktor.http.HttpStatusCode diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt index 37c0581b372..7afb202c1d6 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt @@ -18,7 +18,7 @@ package com.google.firebase.vertexai.util import com.google.firebase.vertexai.GenerativeModel import com.google.firebase.vertexai.common.APIController -import com.google.firebase.vertexai.common.RequestOptions +import com.google.firebase.vertexai.type.RequestOptions import io.kotest.matchers.collections.shouldNotBeEmpty import io.kotest.matchers.nulls.shouldNotBeNull import io.ktor.http.HttpStatusCode