Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion firebase-ai/firebase-ai.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ android {
targetSdk = targetSdkVersion
baseline = file("lint-baseline.xml")
}
sourceSets { getByName("test").java.srcDirs("src/testUtil") }
sourceSets {
// getByName("test").java.srcDirs("src/testUtil")
getByName("androidTest") { kotlin.srcDirs("src/testUtil") }
}
}

// Enable Kotlin "Explicit API Mode". This causes the Kotlin compiler to fail if any
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.google.firebase.ai

import androidx.test.platform.app.InstrumentationRegistry
import com.google.firebase.FirebaseApp
import com.google.firebase.FirebaseOptions
import com.google.firebase.ai.type.GenerativeBackend

class AIModels {

companion object {
private val API_KEY: String = ""
private val APP_ID: String = ""
private val PROJECT_ID: String = "fireescape-integ-tests"
// General purpose models
var app: FirebaseApp? = null
var flash2Model: GenerativeModel? = null
var flash2LiteModel: GenerativeModel? = null

/** Returns a list of general purpose models to test */
fun getModels(): List<GenerativeModel> {
if (flash2Model == null) {
setup()
}
return listOf(flash2Model!!, flash2LiteModel!!)
}

fun app(): FirebaseApp {
if (app == null) {
setup()
}
return app!!
}

fun setup() {
val context = InstrumentationRegistry.getInstrumentation().context;
FirebaseApp.initializeApp(
context,
FirebaseOptions.Builder()
.setApiKey(API_KEY)
.setApplicationId(APP_ID)
.setProjectId(PROJECT_ID)
.build()
)
flash2Model =
FirebaseAI.getInstance(app!!, GenerativeBackend.vertexAI())
.generativeModel(
modelName = "gemini-2.0-flash",
)
flash2LiteModel =
FirebaseAI.getInstance(app!!, GenerativeBackend.vertexAI())
.generativeModel(
modelName = "gemini-2.0-flash-lite",
)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.firebase.ai

import com.google.firebase.ai.AIModels.Companion.getModels
import kotlinx.coroutines.runBlocking
import org.junit.Test

class AiIntegrationTests {
private val validator = TypesValidator()

@Test
fun testBasicResponse() {
for (model in getModels()) {
runBlocking {
val response = model.generateContent("pick a random color")
validator.validateResponse(response)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package com.google.firebase.ai

import android.graphics.Bitmap
import com.google.firebase.ai.AIModels.Companion.getModels
import com.google.firebase.ai.type.Content
import com.google.firebase.ai.type.ContentModality
import com.google.firebase.ai.type.CountTokensResponse
import java.io.ByteArrayOutputStream
import kotlinx.coroutines.runBlocking
import org.junit.Test

class CountTokensTests {

/** Ensures that the token count is expected for simple words. */
@Test
fun testCountTokensAmount() {
for (model in getModels()) {
runBlocking {
val response = model.countTokens("this is five different words")
assert(response.totalTokens == 5)
assert(response.promptTokensDetails.size == 1)
assert(response.promptTokensDetails[0].modality == ContentModality.TEXT)
assert(response.promptTokensDetails[0].tokenCount == 5)
}
}
}

/** Ensures that the model returns token counts in the correct modality for text. */
@Test
fun testCountTokensTextModality() {
for (model in getModels()) {
runBlocking {
val response = model.countTokens("this is a text prompt")
checkTokenCountsMatch(response)
assert(response.promptTokensDetails.size == 1)
assert(containsModality(response, ContentModality.TEXT))
}
}
}

/** Ensures that the model returns token counts in the correct modality for bitmap images. */
@Test
fun testCountTokensImageModality() {
for (model in getModels()) {
runBlocking {
val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
val response = model.countTokens(bitmap)
checkTokenCountsMatch(response)
assert(response.promptTokensDetails.size == 1)
assert(containsModality(response, ContentModality.IMAGE))
}
}
}

/**
* Ensures the model can count tokens for multiple modalities at once, and return the
* corresponding token modalities correctly.
*/
@Test
fun testCountTokensTextAndImageModality() {
for (model in getModels()) {
runBlocking {
val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
val response =
model.countTokens(
Content.Builder().text("this is text").build(),
Content.Builder().image(bitmap).build()
)
checkTokenCountsMatch(response)
assert(response.promptTokensDetails.size == 2)
assert(containsModality(response, ContentModality.TEXT))
assert(containsModality(response, ContentModality.IMAGE))
}
}
}

/**
* Ensures the model can count the tokens for a sent file. Additionally, ensures that the model
* treats this sent file as the modality of the mime type, in this case, a plaintext file has its
* tokens counted as `ContentModality.TEXT`.
*/
@Test
fun testCountTokensTextFileModality() {
for (model in getModels()) {
runBlocking {
val response =
model.countTokens(
Content.Builder().inlineData("this is text".toByteArray(), "text/plain").build()
)
checkTokenCountsMatch(response)
assert(response.totalTokens == 3)
assert(response.promptTokensDetails.size == 1)
assert(containsModality(response, ContentModality.TEXT))
}
}
}

/**
* Ensures the model can count the tokens for a sent file. Additionally, ensures that the model
* treats this sent file as the modality of the mime type, in this case, a PNG encoded bitmap has
* its tokens counted as `ContentModality.IMAGE`.
*/
@Test
fun testCountTokensImageFileModality() {
for (model in getModels()) {
runBlocking {
val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
val stream = ByteArrayOutputStream()
bitmap.compress(Bitmap.CompressFormat.PNG, 1, stream)
val array = stream.toByteArray()
val response = model.countTokens(Content.Builder().inlineData(array, "image/png").build())
checkTokenCountsMatch(response)
assert(response.promptTokensDetails.size == 1)
assert(containsModality(response, ContentModality.IMAGE))
}
}
}

/**
* Ensures that nothing is free, that is, empty content contains no tokens. For some reason, this
* is treated as `ContentModality.TEXT`.
*/
@Test
fun testCountTokensNothingIsFree() {
for (model in getModels()) {
runBlocking {
val response = model.countTokens(Content.Builder().build())
checkTokenCountsMatch(response)
assert(response.totalTokens == 0)
assert(response.promptTokensDetails.size == 1)
assert(containsModality(response, ContentModality.TEXT))
}
}
}

/**
* Checks if the model can count the tokens for a sent file. Additionally, ensures that the model
* treats this sent file as the modality of the mime type, in this case, a JSON file is not
* recognized, and no tokens are counted. This ensures if/when the model can handle JSON, our
* testing makes us aware.
*/
@Test
fun testCountTokensJsonFileModality() {
for (model in getModels()) {
runBlocking {
val json =
"""
{
"foo": "bar",
"baz": 3,
"qux": [
{
"quux": [
1,
2
]
}
]
}
"""
.trimIndent()
val response =
model.countTokens(
Content.Builder().inlineData(json.toByteArray(), "application/json").build()
)
checkTokenCountsMatch(response)
assert(response.promptTokensDetails.isEmpty())
assert(response.totalTokens == 0)
}
}
}

fun checkTokenCountsMatch(response: CountTokensResponse) {
assert(sumTokenCount(response) == response.totalTokens)
}

fun sumTokenCount(response: CountTokensResponse): Int {
return response.promptTokensDetails.sumOf { it.tokenCount }
}

fun containsModality(response: CountTokensResponse, modality: ContentModality): Boolean {
for (token in response.promptTokensDetails) {
if (token.modality == modality) {
return true
}
}
return false
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.google.firebase.ai

import android.graphics.Bitmap
import com.google.firebase.ai.AIModels.Companion.getModels
import com.google.firebase.ai.type.Content
import kotlinx.coroutines.runBlocking
import org.junit.Test

class GenerateContentTests {
private val validator = TypesValidator()

/**
* Ensures the model can response to prompts and that the structure of this response is expected.
*/
@Test
fun testGenerateContent_BasicRequest() {
for (model in getModels()) {
runBlocking {
val response = model.generateContent("pick a random color")
validator.validateResponse(response)
}
}
}

/**
* Ensures that the model can answer very simple questions. Further testing the "logic" of the
* model and the content of the responses is prone to flaking, this test is also prone to that.
* This is probably the furthest we can consistently test for reasonable response structure, past
* sending the request and response back to the model and asking it if it fits our expectations.
*/
@Test
fun testGenerateContent_ColorMixing() {
for (model in getModels()) {
runBlocking {
val response = model.generateContent("what color is created when red and yellow are mixed?")
validator.validateResponse(response)
assert(response.text!!.contains("orange", true))
}
}
}

/**
* Ensures that the model can answer very simple questions. Further testing the "logic" of the
* model and the content of the responses is prone to flaking, this test is also prone to that.
* This is probably the furthest we can consistently test for reasonable response structure, past
* sending the request and response back to the model and asking it if it fits our expectations.
*/
@Test
fun testGenerateContent_CanSendImage() {
for (model in getModels()) {
runBlocking {
val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
val yellow = Integer.parseUnsignedInt("FFFFFF00", 16)
bitmap.setPixel(3, 3, yellow)
bitmap.setPixel(6, 3, yellow)
bitmap.setPixel(3, 6, yellow)
bitmap.setPixel(4, 7, yellow)
bitmap.setPixel(5, 7, yellow)
bitmap.setPixel(6, 6, yellow)
val response =
model.generateContent(
Content.Builder().text("here is a tiny smile").image(bitmap).build()
)
validator.validateResponse(response)
}
}
}

@Test
fun testGenerateContent_Tools() {
for (model in getModels()) {
runBlocking {
val response = model.generateContent(Content.Builder().text("here is a tiny smile").build())
validator.validateResponse(response)
model.startChat()
}
}
}
}
Loading
Loading