Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firebase-vertexai/firebase-vertexai.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ dependencies {
testImplementation(libs.kotlin.coroutines.test)
testImplementation(libs.robolectric)
testImplementation(libs.truth)
testImplementation(libs.mockito.core)

androidTestImplementation(libs.androidx.espresso.core)
androidTestImplementation(libs.androidx.test.junit)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ internal constructor(
return GenerativeModel(
"projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}",
firebaseApp.options.apiKey,
firebaseApp,
generationConfig,
safetySettings,
tools,
Expand Down Expand Up @@ -105,6 +106,7 @@ internal constructor(
return ImagenModel(
"projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}",
firebaseApp.options.apiKey,
firebaseApp,
generationConfig,
safetySettings,
requestOptions,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.firebase.vertexai

import android.graphics.Bitmap
import com.google.firebase.FirebaseApp
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.vertexai.common.APIController
Expand Down Expand Up @@ -59,6 +60,7 @@ internal constructor(
internal constructor(
modelName: String,
apiKey: String,
firebaseApp: FirebaseApp,
generationConfig: GenerationConfig? = null,
safetySettings: List<SafetySetting>? = null,
tools: List<Tool>? = null,
Expand All @@ -79,6 +81,7 @@ internal constructor(
modelName,
requestOptions,
"gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}",
firebaseApp,
AppCheckHeaderProvider(TAG, appCheckTokenProvider, internalAuthProvider),
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.firebase.vertexai

import com.google.firebase.FirebaseApp
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.vertexai.common.APIController
Expand Down Expand Up @@ -46,6 +47,7 @@ internal constructor(
internal constructor(
modelName: String,
apiKey: String,
firebaseApp: FirebaseApp,
generationConfig: ImagenGenerationConfig? = null,
safetySettings: ImagenSafetySettings? = null,
requestOptions: RequestOptions = RequestOptions(),
Expand All @@ -60,6 +62,7 @@ internal constructor(
modelName,
requestOptions,
"gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}",
firebaseApp,
AppCheckHeaderProvider(TAG, appCheckTokenProvider, internalAuthProvider),
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.google.firebase.vertexai.common

import android.util.Log
import com.google.firebase.Firebase
import com.google.firebase.FirebaseApp
import com.google.firebase.options
import com.google.firebase.vertexai.common.util.decodeToFlow
import com.google.firebase.vertexai.common.util.fullModelName
Expand Down Expand Up @@ -91,6 +92,9 @@ internal constructor(
private val requestOptions: RequestOptions,
httpEngine: HttpClientEngine,
private val apiClient: String,
private val firebaseApp: FirebaseApp,
private val appVersion: Int = 0,
private val googleAppId: String,
private val headerProvider: HeaderProvider?,
) {

Expand All @@ -99,8 +103,19 @@ internal constructor(
model: String,
requestOptions: RequestOptions,
apiClient: String,
firebaseApp: FirebaseApp,
headerProvider: HeaderProvider? = null,
) : this(key, model, requestOptions, OkHttp.create(), apiClient, headerProvider)
) : this(
key,
model,
requestOptions,
OkHttp.create(),
apiClient,
firebaseApp,
getVersionNumber(firebaseApp),
firebaseApp.options.applicationId,
headerProvider
)

private val model = fullModelName(model)

Expand Down Expand Up @@ -175,6 +190,10 @@ internal constructor(
contentType(ContentType.Application.Json)
header("x-goog-api-key", key)
header("x-goog-api-client", apiClient)
if (firebaseApp.isDataCollectionDefaultEnabled) {
header("X-Firebase-AppId", googleAppId)
header("X-Firebase-AppVersion", appVersion)
}
}

private suspend fun HttpRequestBuilder.applyHeaderProvider() {
Expand Down Expand Up @@ -240,6 +259,16 @@ internal constructor(

companion object {
private val TAG = APIController::class.java.simpleName

private fun getVersionNumber(app: FirebaseApp): Int {
try {
val context = app.applicationContext
return context.packageManager.getPackageInfo(context.packageName, 0).versionCode
} catch (e: Exception) {
Log.d(TAG, "Error while getting app version: ${e.message}")
return 0
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.firebase.vertexai

import com.google.firebase.FirebaseApp
import com.google.firebase.vertexai.common.APIController
import com.google.firebase.vertexai.common.JSON
import com.google.firebase.vertexai.common.util.doBlocking
Expand All @@ -42,10 +43,21 @@ import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.encodeToString
import org.junit.Before
import org.junit.Test
import org.mockito.Mockito

internal class GenerativeModelTesting {
private val TEST_CLIENT_ID = "test"
private val TEST_APP_ID = "1:android:12345"
private val TEST_VERSION = 1

private var mockFirebaseApp: FirebaseApp = Mockito.mock<FirebaseApp>()

@Before
fun setup() {
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
}

@Test
fun `system calling in request`() = doBlocking {
Expand All @@ -64,6 +76,9 @@ internal class GenerativeModelTesting {
RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

Expand Down Expand Up @@ -109,6 +124,9 @@ internal class GenerativeModelTesting {
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.firebase.vertexai.common

import com.google.firebase.FirebaseApp
import com.google.firebase.vertexai.BuildConfig
import com.google.firebase.vertexai.common.util.commonTest
import com.google.firebase.vertexai.common.util.createResponses
Expand Down Expand Up @@ -49,12 +50,18 @@ import kotlinx.coroutines.withTimeout
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.JsonObject
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.mockito.Mockito

private val TEST_CLIENT_ID = "genai-android/test"

private val TEST_APP_ID = "1:android:12345"

private val TEST_VERSION = 1

internal class APIControllerTests {
private val testTimeout = 5.seconds

Expand Down Expand Up @@ -87,6 +94,14 @@ internal class APIControllerTests {

@OptIn(ExperimentalSerializationApi::class)
internal class RequestFormatTests {

private val mockFirebaseApp = Mockito.mock<FirebaseApp>()

@Before
fun setup() {
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
}

@Test
fun `using default endpoint`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
Expand All @@ -101,6 +116,9 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
"genai-android/${BuildConfig.VERSION_NAME}",
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

Expand Down Expand Up @@ -128,6 +146,9 @@ internal class RequestFormatTests {
RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

Expand Down Expand Up @@ -155,6 +176,9 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

Expand All @@ -163,6 +187,35 @@ internal class RequestFormatTests {
mockEngine.requestHistory.first().headers["x-goog-api-client"] shouldBe TEST_CLIENT_ID
}

@Test
fun `ml monitoring header is set correctly if data collection is enabled`() = doBlocking {
val response = JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10))
val mockEngine = MockEngine {
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
}

Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(true)

val controller =
APIController(
"super_cool_test_key",
"gemini-pro-1.5",
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }

mockEngine.requestHistory.first().headers["X-Firebase-AppId"] shouldBe TEST_APP_ID
mockEngine.requestHistory.first().headers["X-Firebase-AppVersion"] shouldBe
TEST_VERSION.toString()
}

@Test
fun `ToolConfig serialization contains correct keys`() = doBlocking {
val channel = ByteChannel(autoFlush = true)
Expand All @@ -178,6 +231,9 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

Expand Down Expand Up @@ -229,6 +285,9 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
testHeaderProvider,
)

Expand Down Expand Up @@ -263,6 +322,9 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
testHeaderProvider,
)

Expand All @@ -286,6 +348,9 @@ internal class RequestFormatTests {
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

Expand All @@ -309,6 +374,12 @@ internal class RequestFormatTests {

@RunWith(Parameterized::class)
internal class ModelNamingTests(private val modelName: String, private val actualName: String) {
private val mockFirebaseApp = Mockito.mock<FirebaseApp>()

@Before
fun setup() {
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)
}

@Test
fun `request should include right model name`() = doBlocking {
Expand All @@ -324,6 +395,9 @@ internal class ModelNamingTests(private val modelName: String, private val actua
RequestOptions(),
mockEngine,
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package com.google.firebase.vertexai.common.util

import com.google.firebase.FirebaseApp
import com.google.firebase.vertexai.common.APIController
import com.google.firebase.vertexai.common.JSON
import com.google.firebase.vertexai.type.Candidate
Expand All @@ -33,8 +34,11 @@ import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.encodeToString
import org.mockito.Mockito

private val TEST_CLIENT_ID = "genai-android/test"
private val TEST_APP_ID = "1:android:12345"
private val TEST_VERSION = 1

internal fun prepareStreamingResponse(
response: List<GenerateContentResponse.Internal>
Expand Down Expand Up @@ -90,6 +94,9 @@ internal fun commonTest(
requestOptions: RequestOptions = RequestOptions(),
block: CommonTest,
) = doBlocking {
val mockFirebaseApp = Mockito.mock<FirebaseApp>()
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false)

val channel = ByteChannel(autoFlush = true)
val apiController =
APIController(
Expand All @@ -100,6 +107,9 @@ internal fun commonTest(
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
},
TEST_CLIENT_ID,
mockFirebaseApp,
TEST_VERSION,
TEST_APP_ID,
null,
)
CommonTestScope(channel, apiController).block()
Expand Down
Loading