diff --git a/firebase-vertexai/firebase-vertexai.gradle.kts b/firebase-vertexai/firebase-vertexai.gradle.kts index b43909559e0..08942b6bedf 100644 --- a/firebase-vertexai/firebase-vertexai.gradle.kts +++ b/firebase-vertexai/firebase-vertexai.gradle.kts @@ -115,6 +115,7 @@ dependencies { testImplementation(libs.kotlin.coroutines.test) testImplementation(libs.robolectric) testImplementation(libs.truth) + testImplementation(libs.mockito.core) androidTestImplementation(libs.androidx.espresso.core) androidTestImplementation(libs.androidx.test.junit) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt index b89e5671992..1790ec0c300 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt @@ -71,6 +71,7 @@ internal constructor( return GenerativeModel( "projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}", firebaseApp.options.apiKey, + firebaseApp, generationConfig, safetySettings, tools, @@ -105,6 +106,7 @@ internal constructor( return ImagenModel( "projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}", firebaseApp.options.apiKey, + firebaseApp, generationConfig, safetySettings, requestOptions, diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt index 12d89ab5b59..3520aff2238 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/GenerativeModel.kt @@ -17,6 +17,7 @@ package com.google.firebase.vertexai import android.graphics.Bitmap +import com.google.firebase.FirebaseApp import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider import com.google.firebase.auth.internal.InternalAuthProvider import com.google.firebase.vertexai.common.APIController @@ -59,6 +60,7 @@ internal constructor( internal constructor( modelName: String, apiKey: String, + firebaseApp: FirebaseApp, generationConfig: GenerationConfig? = null, safetySettings: List? = null, tools: List? = null, @@ -79,6 +81,7 @@ internal constructor( modelName, requestOptions, "gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}", + firebaseApp, AppCheckHeaderProvider(TAG, appCheckTokenProvider, internalAuthProvider), ), ) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/ImagenModel.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/ImagenModel.kt index 583ef24bcc4..fa33ee6e327 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/ImagenModel.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/ImagenModel.kt @@ -16,6 +16,7 @@ package com.google.firebase.vertexai +import com.google.firebase.FirebaseApp import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider import com.google.firebase.auth.internal.InternalAuthProvider import com.google.firebase.vertexai.common.APIController @@ -46,6 +47,7 @@ internal constructor( internal constructor( modelName: String, apiKey: String, + firebaseApp: FirebaseApp, generationConfig: ImagenGenerationConfig? = null, safetySettings: ImagenSafetySettings? = null, requestOptions: RequestOptions = RequestOptions(), @@ -60,6 +62,7 @@ internal constructor( modelName, requestOptions, "gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}", + firebaseApp, AppCheckHeaderProvider(TAG, appCheckTokenProvider, internalAuthProvider), ), ) diff --git a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt index f8bfe0bc24f..c67e21ccf23 100644 --- a/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt +++ b/firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt @@ -18,6 +18,7 @@ package com.google.firebase.vertexai.common import android.util.Log import com.google.firebase.Firebase +import com.google.firebase.FirebaseApp import com.google.firebase.options import com.google.firebase.vertexai.common.util.decodeToFlow import com.google.firebase.vertexai.common.util.fullModelName @@ -91,6 +92,9 @@ internal constructor( private val requestOptions: RequestOptions, httpEngine: HttpClientEngine, private val apiClient: String, + private val firebaseApp: FirebaseApp, + private val appVersion: Int = 0, + private val googleAppId: String, private val headerProvider: HeaderProvider?, ) { @@ -99,8 +103,19 @@ internal constructor( model: String, requestOptions: RequestOptions, apiClient: String, + firebaseApp: FirebaseApp, headerProvider: HeaderProvider? = null, - ) : this(key, model, requestOptions, OkHttp.create(), apiClient, headerProvider) + ) : this( + key, + model, + requestOptions, + OkHttp.create(), + apiClient, + firebaseApp, + getVersionNumber(firebaseApp), + firebaseApp.options.applicationId, + headerProvider + ) private val model = fullModelName(model) @@ -175,6 +190,10 @@ internal constructor( contentType(ContentType.Application.Json) header("x-goog-api-key", key) header("x-goog-api-client", apiClient) + if (firebaseApp.isDataCollectionDefaultEnabled) { + header("X-Firebase-AppId", googleAppId) + header("X-Firebase-AppVersion", appVersion) + } } private suspend fun HttpRequestBuilder.applyHeaderProvider() { @@ -240,6 +259,16 @@ internal constructor( companion object { private val TAG = APIController::class.java.simpleName + + private fun getVersionNumber(app: FirebaseApp): Int { + try { + val context = app.applicationContext + return context.packageManager.getPackageInfo(context.packageName, 0).versionCode + } catch (e: Exception) { + Log.d(TAG, "Error while getting app version: ${e.message}") + return 0 + } + } } } diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt index d4c2ad37926..e66918ad52f 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt @@ -16,6 +16,7 @@ package com.google.firebase.vertexai +import com.google.firebase.FirebaseApp import com.google.firebase.vertexai.common.APIController import com.google.firebase.vertexai.common.JSON import com.google.firebase.vertexai.common.util.doBlocking @@ -42,10 +43,21 @@ import kotlin.time.Duration.Companion.seconds import kotlinx.coroutines.withTimeout import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString +import org.junit.Before import org.junit.Test +import org.mockito.Mockito internal class GenerativeModelTesting { private val TEST_CLIENT_ID = "test" + private val TEST_APP_ID = "1:android:12345" + private val TEST_VERSION = 1 + + private var mockFirebaseApp: FirebaseApp = Mockito.mock() + + @Before + fun setup() { + Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false) + } @Test fun `system calling in request`() = doBlocking { @@ -64,6 +76,9 @@ internal class GenerativeModelTesting { RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"), mockEngine, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, null, ) @@ -109,6 +124,9 @@ internal class GenerativeModelTesting { RequestOptions(), mockEngine, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, null, ) diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt index 29b52b81d1b..0d668849156 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt @@ -16,6 +16,7 @@ package com.google.firebase.vertexai.common +import com.google.firebase.FirebaseApp import com.google.firebase.vertexai.BuildConfig import com.google.firebase.vertexai.common.util.commonTest import com.google.firebase.vertexai.common.util.createResponses @@ -49,12 +50,18 @@ import kotlinx.coroutines.withTimeout import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString import kotlinx.serialization.json.JsonObject +import org.junit.Before import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized +import org.mockito.Mockito private val TEST_CLIENT_ID = "genai-android/test" +private val TEST_APP_ID = "1:android:12345" + +private val TEST_VERSION = 1 + internal class APIControllerTests { private val testTimeout = 5.seconds @@ -87,6 +94,14 @@ internal class APIControllerTests { @OptIn(ExperimentalSerializationApi::class) internal class RequestFormatTests { + + private val mockFirebaseApp = Mockito.mock() + + @Before + fun setup() { + Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false) + } + @Test fun `using default endpoint`() = doBlocking { val channel = ByteChannel(autoFlush = true) @@ -101,6 +116,9 @@ internal class RequestFormatTests { RequestOptions(), mockEngine, "genai-android/${BuildConfig.VERSION_NAME}", + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, null, ) @@ -128,6 +146,9 @@ internal class RequestFormatTests { RequestOptions(timeout = 5.seconds, endpoint = "https://my.custom.endpoint"), mockEngine, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, null, ) @@ -155,6 +176,9 @@ internal class RequestFormatTests { RequestOptions(), mockEngine, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, null, ) @@ -163,6 +187,35 @@ internal class RequestFormatTests { mockEngine.requestHistory.first().headers["x-goog-api-client"] shouldBe TEST_CLIENT_ID } + @Test + fun `ml monitoring header is set correctly if data collection is enabled`() = doBlocking { + val response = JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10)) + val mockEngine = MockEngine { + respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json")) + } + + Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(true) + + val controller = + APIController( + "super_cool_test_key", + "gemini-pro-1.5", + RequestOptions(), + mockEngine, + TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, + null, + ) + + withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) } + + mockEngine.requestHistory.first().headers["X-Firebase-AppId"] shouldBe TEST_APP_ID + mockEngine.requestHistory.first().headers["X-Firebase-AppVersion"] shouldBe + TEST_VERSION.toString() + } + @Test fun `ToolConfig serialization contains correct keys`() = doBlocking { val channel = ByteChannel(autoFlush = true) @@ -178,6 +231,9 @@ internal class RequestFormatTests { RequestOptions(), mockEngine, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, null, ) @@ -229,6 +285,9 @@ internal class RequestFormatTests { RequestOptions(), mockEngine, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, testHeaderProvider, ) @@ -263,6 +322,9 @@ internal class RequestFormatTests { RequestOptions(), mockEngine, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, testHeaderProvider, ) @@ -286,6 +348,9 @@ internal class RequestFormatTests { RequestOptions(), mockEngine, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, null, ) @@ -309,6 +374,12 @@ internal class RequestFormatTests { @RunWith(Parameterized::class) internal class ModelNamingTests(private val modelName: String, private val actualName: String) { + private val mockFirebaseApp = Mockito.mock() + + @Before + fun setup() { + Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false) + } @Test fun `request should include right model name`() = doBlocking { @@ -324,6 +395,9 @@ internal class ModelNamingTests(private val modelName: String, private val actua RequestOptions(), mockEngine, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, null, ) diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt index 855c8aa4a8b..320cf381467 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt @@ -18,6 +18,7 @@ package com.google.firebase.vertexai.common.util +import com.google.firebase.FirebaseApp import com.google.firebase.vertexai.common.APIController import com.google.firebase.vertexai.common.JSON import com.google.firebase.vertexai.type.Candidate @@ -33,8 +34,11 @@ import io.ktor.http.headersOf import io.ktor.utils.io.ByteChannel import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.encodeToString +import org.mockito.Mockito private val TEST_CLIENT_ID = "genai-android/test" +private val TEST_APP_ID = "1:android:12345" +private val TEST_VERSION = 1 internal fun prepareStreamingResponse( response: List @@ -90,6 +94,9 @@ internal fun commonTest( requestOptions: RequestOptions = RequestOptions(), block: CommonTest, ) = doBlocking { + val mockFirebaseApp = Mockito.mock() + Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false) + val channel = ByteChannel(autoFlush = true) val apiController = APIController( @@ -100,6 +107,9 @@ internal fun commonTest( respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) }, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, null, ) CommonTestScope(channel, apiController).block() diff --git a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt index 4f648735396..a683c1d5032 100644 --- a/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt +++ b/firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt @@ -18,6 +18,7 @@ package com.google.firebase.vertexai.util +import com.google.firebase.FirebaseApp import com.google.firebase.vertexai.GenerativeModel import com.google.firebase.vertexai.ImagenModel import com.google.firebase.vertexai.common.APIController @@ -35,8 +36,11 @@ import io.ktor.utils.io.close import io.ktor.utils.io.writeFully import java.io.File import kotlinx.coroutines.launch +import org.mockito.Mockito private val TEST_CLIENT_ID = "firebase-vertexai-android/test" +private val TEST_APP_ID = "1:android:12345" +private val TEST_VERSION = 1 /** String separator used in SSE communication to signal the end of a message. */ internal const val SSE_SEPARATOR = "\r\n\r\n" @@ -100,6 +104,9 @@ internal fun commonTest( block: CommonTest, ) = doBlocking { val channel = ByteChannel(autoFlush = true) + val mockFirebaseApp = Mockito.mock() + Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(false) + val apiController = APIController( "super_cool_test_key", @@ -109,6 +116,9 @@ internal fun commonTest( respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json")) }, TEST_CLIENT_ID, + mockFirebaseApp, + TEST_VERSION, + TEST_APP_ID, null, ) val model = GenerativeModel("cool-model-name", controller = apiController)