Skip to content

Commit 418af2f

Browse files
author
David Motsonashvili
committed
minor restructure and added tests
1 parent a324a1a commit 418af2f

File tree

5 files changed

+95
-12
lines changed

5 files changed

+95
-12
lines changed

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.firebase.vertexai.common
1818

19+
import android.content.Context
1920
import android.util.Log
2021
import com.google.firebase.Firebase
2122
import com.google.firebase.FirebaseApp
@@ -93,6 +94,8 @@ internal constructor(
9394
httpEngine: HttpClientEngine,
9495
private val apiClient: String,
9596
private val firebaseApp: FirebaseApp,
97+
private val appVersion: Int = 0,
98+
private val googleAppId: String,
9699
private val headerProvider: HeaderProvider?,
97100
) {
98101

@@ -103,7 +106,17 @@ internal constructor(
103106
apiClient: String,
104107
firebaseApp: FirebaseApp,
105108
headerProvider: HeaderProvider? = null,
106-
) : this(key, model, requestOptions, OkHttp.create(), apiClient, firebaseApp, headerProvider)
109+
) : this(
110+
key,
111+
model,
112+
requestOptions,
113+
OkHttp.create(),
114+
apiClient,
115+
firebaseApp,
116+
getVersionNumber(firebaseApp),
117+
firebaseApp.options.applicationId,
118+
headerProvider
119+
)
107120

108121
private val model = fullModelName(model)
109122

@@ -179,20 +192,11 @@ internal constructor(
179192
header("x-goog-api-key", key)
180193
header("x-goog-api-client", apiClient)
181194
if (firebaseApp.isDataCollectionDefaultEnabled) {
182-
header("X-Firebase-AppId", firebaseApp.options.applicationId)
183-
header("X-Firebase-AppVersion", getVersionNumber())
195+
header("X-Firebase-AppId", googleAppId)
196+
header("X-Firebase-AppVersion", appVersion)
184197
}
185198
}
186199

187-
private fun getVersionNumber(): Int {
188-
try {
189-
val context = firebaseApp.applicationContext
190-
return context.packageManager.getPackageInfo(context.packageName, 0).versionCode
191-
} catch (e: Exception) {
192-
Log.d(TAG, "Error while getting app version: ${e.message}")
193-
return 0
194-
}
195-
}
196200

197201
private suspend fun HttpRequestBuilder.applyHeaderProvider() {
198202
if (headerProvider != null) {
@@ -257,6 +261,16 @@ internal constructor(
257261

258262
companion object {
259263
private val TAG = APIController::class.java.simpleName
264+
265+
private fun getVersionNumber(app: FirebaseApp): Int {
266+
try {
267+
val context = app.applicationContext
268+
return context.packageManager.getPackageInfo(context.packageName, 0).versionCode
269+
} catch (e: Exception) {
270+
Log.d(TAG, "Error while getting app version: ${e.message}")
271+
return 0
272+
}
273+
}
260274
}
261275
}
262276

firebase-vertexai/src/test/java/com/google/firebase/vertexai/GenerativeModelTesting.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ import org.mockito.Mockito
4949

5050
internal class GenerativeModelTesting {
5151
private val TEST_CLIENT_ID = "test"
52+
private val TEST_APP_ID = "1:android:12345"
53+
private val TEST_VERSION = 1
5254

5355
private var mockFirebaseApp: FirebaseApp = Mockito.mock<FirebaseApp>()
5456

@@ -75,6 +77,8 @@ internal class GenerativeModelTesting {
7577
mockEngine,
7678
TEST_CLIENT_ID,
7779
mockFirebaseApp,
80+
TEST_VERSION,
81+
TEST_APP_ID,
7882
null,
7983
)
8084

@@ -121,6 +125,8 @@ internal class GenerativeModelTesting {
121125
mockEngine,
122126
TEST_CLIENT_ID,
123127
mockFirebaseApp,
128+
TEST_VERSION,
129+
TEST_APP_ID,
124130
null,
125131
)
126132

firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/APIControllerTests.kt

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616

1717
package com.google.firebase.vertexai.common
1818

19+
import android.content.Context
20+
import android.content.pm.PackageInfo
21+
import android.content.pm.PackageManager
1922
import com.google.firebase.FirebaseApp
23+
import com.google.firebase.FirebaseOptions
2024
import com.google.firebase.vertexai.BuildConfig
2125
import com.google.firebase.vertexai.common.util.commonTest
2226
import com.google.firebase.vertexai.common.util.createResponses
@@ -54,10 +58,16 @@ import org.junit.Before
5458
import org.junit.Test
5559
import org.junit.runner.RunWith
5660
import org.junit.runners.Parameterized
61+
import org.mockito.ArgumentMatchers.any
62+
import org.mockito.ArgumentMatchers.anyInt
5763
import org.mockito.Mockito
5864

5965
private val TEST_CLIENT_ID = "genai-android/test"
6066

67+
private val TEST_APP_ID = "1:android:12345"
68+
69+
private val TEST_VERSION = 1
70+
6171
internal class APIControllerTests {
6272
private val testTimeout = 5.seconds
6373

@@ -113,6 +123,8 @@ internal class RequestFormatTests {
113123
mockEngine,
114124
"genai-android/${BuildConfig.VERSION_NAME}",
115125
mockFirebaseApp,
126+
TEST_VERSION,
127+
TEST_APP_ID,
116128
null,
117129
)
118130

@@ -141,6 +153,8 @@ internal class RequestFormatTests {
141153
mockEngine,
142154
TEST_CLIENT_ID,
143155
mockFirebaseApp,
156+
TEST_VERSION,
157+
TEST_APP_ID,
144158
null,
145159
)
146160

@@ -169,6 +183,8 @@ internal class RequestFormatTests {
169183
mockEngine,
170184
TEST_CLIENT_ID,
171185
mockFirebaseApp,
186+
TEST_VERSION,
187+
TEST_APP_ID,
172188
null,
173189
)
174190

@@ -177,6 +193,35 @@ internal class RequestFormatTests {
177193
mockEngine.requestHistory.first().headers["x-goog-api-client"] shouldBe TEST_CLIENT_ID
178194
}
179195

196+
@Test
197+
fun `ml monitoring header is set correctly if data collection is enabled`() = doBlocking {
198+
val response = JSON.encodeToString(CountTokensResponse.Internal(totalTokens = 10))
199+
val mockEngine = MockEngine {
200+
respond(response, HttpStatusCode.OK, headersOf(HttpHeaders.ContentType, "application/json"))
201+
}
202+
203+
Mockito.`when`(mockFirebaseApp.isDataCollectionDefaultEnabled).thenReturn(true)
204+
205+
val controller =
206+
APIController(
207+
"super_cool_test_key",
208+
"gemini-pro-1.5",
209+
RequestOptions(),
210+
mockEngine,
211+
TEST_CLIENT_ID,
212+
mockFirebaseApp,
213+
TEST_VERSION,
214+
TEST_APP_ID,
215+
null,
216+
)
217+
218+
withTimeout(5.seconds) { controller.countTokens(textCountTokenRequest("cats")) }
219+
220+
mockEngine.requestHistory.first().headers["X-Firebase-AppId"] shouldBe TEST_APP_ID
221+
mockEngine.requestHistory.first().headers["X-Firebase-AppVersion"] shouldBe
222+
TEST_VERSION.toString()
223+
}
224+
180225
@Test
181226
fun `ToolConfig serialization contains correct keys`() = doBlocking {
182227
val channel = ByteChannel(autoFlush = true)
@@ -193,6 +238,8 @@ internal class RequestFormatTests {
193238
mockEngine,
194239
TEST_CLIENT_ID,
195240
mockFirebaseApp,
241+
TEST_VERSION,
242+
TEST_APP_ID,
196243
null,
197244
)
198245

@@ -245,6 +292,8 @@ internal class RequestFormatTests {
245292
mockEngine,
246293
TEST_CLIENT_ID,
247294
mockFirebaseApp,
295+
TEST_VERSION,
296+
TEST_APP_ID,
248297
testHeaderProvider,
249298
)
250299

@@ -280,6 +329,8 @@ internal class RequestFormatTests {
280329
mockEngine,
281330
TEST_CLIENT_ID,
282331
mockFirebaseApp,
332+
TEST_VERSION,
333+
TEST_APP_ID,
283334
testHeaderProvider,
284335
)
285336

@@ -304,6 +355,8 @@ internal class RequestFormatTests {
304355
mockEngine,
305356
TEST_CLIENT_ID,
306357
mockFirebaseApp,
358+
TEST_VERSION,
359+
TEST_APP_ID,
307360
null,
308361
)
309362

@@ -349,6 +402,8 @@ internal class ModelNamingTests(private val modelName: String, private val actua
349402
mockEngine,
350403
TEST_CLIENT_ID,
351404
mockFirebaseApp,
405+
TEST_VERSION,
406+
TEST_APP_ID,
352407
null,
353408
)
354409

firebase-vertexai/src/test/java/com/google/firebase/vertexai/common/util/tests.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ import kotlinx.serialization.encodeToString
3737
import org.mockito.Mockito
3838

3939
private val TEST_CLIENT_ID = "genai-android/test"
40+
private val TEST_APP_ID = "1:android:12345"
41+
private val TEST_VERSION = 1
4042

4143
internal fun prepareStreamingResponse(
4244
response: List<GenerateContentResponse.Internal>
@@ -106,6 +108,8 @@ internal fun commonTest(
106108
},
107109
TEST_CLIENT_ID,
108110
mockFirebaseApp,
111+
TEST_VERSION,
112+
TEST_APP_ID,
109113
null,
110114
)
111115
CommonTestScope(channel, apiController).block()

firebase-vertexai/src/test/java/com/google/firebase/vertexai/util/tests.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ import kotlinx.coroutines.launch
3939
import org.mockito.Mockito
4040

4141
private val TEST_CLIENT_ID = "firebase-vertexai-android/test"
42+
private val TEST_APP_ID = "1:android:12345"
43+
private val TEST_VERSION = 1
4244

4345
/** String separator used in SSE communication to signal the end of a message. */
4446
internal const val SSE_SEPARATOR = "\r\n\r\n"
@@ -115,6 +117,8 @@ internal fun commonTest(
115117
},
116118
TEST_CLIENT_ID,
117119
mockFirebaseApp,
120+
TEST_VERSION,
121+
TEST_APP_ID,
118122
null,
119123
)
120124
val model = GenerativeModel("cool-model-name", controller = apiController)

0 commit comments

Comments
 (0)