diff --git a/firebase-ai/firebase-ai.gradle.kts b/firebase-ai/firebase-ai.gradle.kts
index de29f44b0a1..3648c0bb80c 100644
--- a/firebase-ai/firebase-ai.gradle.kts
+++ b/firebase-ai/firebase-ai.gradle.kts
@@ -67,7 +67,10 @@ android {
     targetSdk = targetSdkVersion
     baseline = file("lint-baseline.xml")
   }
-  sourceSets { getByName("test").java.srcDirs("src/testUtil") }
+  sourceSets {
+    // getByName("test").java.srcDirs("src/testUtil")
+    getByName("androidTest") { kotlin.srcDirs("src/testUtil") }
+  }
 }
 
 // Enable Kotlin "Explicit API Mode". This causes the Kotlin compiler to fail if any
diff --git a/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt b/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt
new file mode 100644
index 00000000000..2b151a7b34c
--- /dev/null
+++ b/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/AIModels.kt
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.firebase.ai
+
+import androidx.test.platform.app.InstrumentationRegistry
+import com.google.firebase.FirebaseApp
+import com.google.firebase.FirebaseOptions
+import com.google.firebase.ai.type.GenerativeBackend
+
+class AIModels {
+
+  companion object {
+    private val API_KEY: String = ""
+    private val APP_ID: String = ""
+    private val PROJECT_ID: String = "fireescape-integ-tests"
+    // General purpose models
+    var app: FirebaseApp? = null
+    var flash2Model: GenerativeModel? = null
+    var flash2LiteModel: GenerativeModel? = null
+
+    /** Returns a list of general purpose models to test */
+    fun getModels(): List<GenerativeModel> {
+      if (flash2Model == null) {
+        setup()
+      }
+      return listOf(flash2Model!!, flash2LiteModel!!)
+    }
+
+    fun app(): FirebaseApp {
+      if (app == null) {
+        setup()
+      }
+      return app!!
+    }
+
+    fun setup() {
+      val context = InstrumentationRegistry.getInstrumentation().context
+      app =
+        FirebaseApp.initializeApp(
+          context,
+          FirebaseOptions.Builder()
+            .setApiKey(API_KEY)
+            .setApplicationId(APP_ID)
+            .setProjectId(PROJECT_ID)
+            .build()
+        )
+      flash2Model =
+        FirebaseAI.getInstance(app!!, GenerativeBackend.vertexAI())
+          .generativeModel(
+            modelName = "gemini-2.0-flash",
+          )
+      flash2LiteModel =
+        FirebaseAI.getInstance(app!!, GenerativeBackend.vertexAI())
+          .generativeModel(
+            modelName = "gemini-2.0-flash-lite",
+          )
+    }
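+
+    // Illustrative sketch, not part of the original change: the blank API_KEY / APP_ID
+    // placeholders above could instead be supplied at run time, for example through
+    // instrumentation arguments. The argument names used here ("firebaseApiKey",
+    // "firebaseAppId") are hypothetical and only show the shape such a helper could take.
+    fun optionsFromInstrumentationArgs(): FirebaseOptions {
+      val args = InstrumentationRegistry.getArguments()
+      return FirebaseOptions.Builder()
+        .setApiKey(args.getString("firebaseApiKey") ?: API_KEY)
+        .setApplicationId(args.getString("firebaseAppId") ?: APP_ID)
+        .setProjectId(PROJECT_ID)
+        .build()
+    }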
+  }
+}
diff --git a/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/CountTokensTests.kt b/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/CountTokensTests.kt
new file mode 100644
index 00000000000..04ff6262ee8
--- /dev/null
+++ b/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/CountTokensTests.kt
@@ -0,0 +1,204 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.firebase.ai
+
+import android.graphics.Bitmap
+import com.google.firebase.ai.AIModels.Companion.getModels
+import com.google.firebase.ai.type.Content
+import com.google.firebase.ai.type.ContentModality
+import com.google.firebase.ai.type.CountTokensResponse
+import java.io.ByteArrayOutputStream
+import kotlinx.coroutines.runBlocking
+import org.junit.Test
+
+class CountTokensTests {
+
+  /** Ensures that the token count is expected for simple words. */
+  @Test
+  fun testCountTokensAmount() {
+    for (model in getModels()) {
+      runBlocking {
+        val response = model.countTokens("this is five different words")
+        assert(response.totalTokens == 5)
+        assert(response.promptTokensDetails.size == 1)
+        assert(response.promptTokensDetails[0].modality == ContentModality.TEXT)
+        assert(response.promptTokensDetails[0].tokenCount == 5)
+      }
+    }
+  }
+
+  /** Ensures that the model returns token counts in the correct modality for text. */
+  @Test
+  fun testCountTokensTextModality() {
+    for (model in getModels()) {
+      runBlocking {
+        val response = model.countTokens("this is a text prompt")
+        checkTokenCountsMatch(response)
+        assert(response.promptTokensDetails.size == 1)
+        assert(containsModality(response, ContentModality.TEXT))
+      }
+    }
+  }
+
+  /** Ensures that the model returns token counts in the correct modality for bitmap images. */
+  @Test
+  fun testCountTokensImageModality() {
+    for (model in getModels()) {
+      runBlocking {
+        val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
+        val response = model.countTokens(bitmap)
+        checkTokenCountsMatch(response)
+        assert(response.promptTokensDetails.size == 1)
+        assert(containsModality(response, ContentModality.IMAGE))
+      }
+    }
+  }
+
+  /**
+   * Ensures the model can count tokens for multiple modalities at once, and return the
+   * corresponding token modalities correctly.
+   */
+  @Test
+  fun testCountTokensTextAndImageModality() {
+    for (model in getModels()) {
+      runBlocking {
+        val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
+        val response =
+          model.countTokens(
+            Content.Builder().text("this is text").build(),
+            Content.Builder().image(bitmap).build()
+          )
+        checkTokenCountsMatch(response)
+        assert(response.promptTokensDetails.size == 2)
+        assert(containsModality(response, ContentModality.TEXT))
+        assert(containsModality(response, ContentModality.IMAGE))
+      }
+    }
+  }
+
+  /**
+   * Ensures the model can count the tokens for a sent file. Additionally, ensures that the model
+   * treats this sent file as the modality of the mime type, in this case, a plaintext file has its
+   * tokens counted as `ContentModality.TEXT`.
+   */
+  @Test
+  fun testCountTokensTextFileModality() {
+    for (model in getModels()) {
+      runBlocking {
+        val response =
+          model.countTokens(
+            Content.Builder().inlineData("this is text".toByteArray(), "text/plain").build()
+          )
+        checkTokenCountsMatch(response)
+        assert(response.totalTokens == 3)
+        assert(response.promptTokensDetails.size == 1)
+        assert(containsModality(response, ContentModality.TEXT))
+      }
+    }
+  }
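+
+  /**
+   * Illustrative sketch, not part of the original test set: a prompt built from two text parts
+   * should still report a single TEXT modality entry whose count matches the total. Assumes
+   * `Content.Builder().text()` can be chained to add multiple parts, as it is elsewhere in these
+   * tests.
+   */
+  @Test
+  fun testCountTokensMultipleTextParts() {
+    for (model in getModels()) {
+      runBlocking {
+        val response =
+          model.countTokens(Content.Builder().text("first part").text("second part").build())
+        checkTokenCountsMatch(response)
+        assert(response.promptTokensDetails.size == 1)
+        assert(containsModality(response, ContentModality.TEXT))
+      }
+    }
+  }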
+
+  /**
+   * Ensures the model can count the tokens for a sent file. Additionally, ensures that the model
+   * treats this sent file as the modality of the mime type, in this case, a PNG encoded bitmap has
+   * its tokens counted as `ContentModality.IMAGE`.
+   */
+  @Test
+  fun testCountTokensImageFileModality() {
+    for (model in getModels()) {
+      runBlocking {
+        val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
+        val stream = ByteArrayOutputStream()
+        bitmap.compress(Bitmap.CompressFormat.PNG, 1, stream)
+        val array = stream.toByteArray()
+        val response = model.countTokens(Content.Builder().inlineData(array, "image/png").build())
+        checkTokenCountsMatch(response)
+        assert(response.promptTokensDetails.size == 1)
+        assert(containsModality(response, ContentModality.IMAGE))
+      }
+    }
+  }
+
+  /**
+   * Ensures that nothing is free, that is, empty content contains no tokens. For some reason, this
+   * is treated as `ContentModality.TEXT`.
+   */
+  @Test
+  fun testCountTokensNothingIsFree() {
+    for (model in getModels()) {
+      runBlocking {
+        val response = model.countTokens(Content.Builder().build())
+        checkTokenCountsMatch(response)
+        assert(response.totalTokens == 0)
+        assert(response.promptTokensDetails.size == 1)
+        assert(containsModality(response, ContentModality.TEXT))
+      }
+    }
+  }
+
+  /**
+   * Checks if the model can count the tokens for a sent file. Additionally, ensures that the model
+   * treats this sent file as the modality of the mime type, in this case, a JSON file is not
+   * recognized, and no tokens are counted. This ensures that if/when the model starts handling
+   * JSON, our testing will make us aware.
+   */
+  @Test
+  fun testCountTokensJsonFileModality() {
+    for (model in getModels()) {
+      runBlocking {
+        val json =
+          """
+            {
+              "foo": "bar",
+              "baz": 3,
+              "qux": [
+                {
+                  "quux": [
+                    1,
+                    2
+                  ]
+                }
+              ]
+            }
+          """
+            .trimIndent()
+        val response =
+          model.countTokens(
+            Content.Builder().inlineData(json.toByteArray(), "application/json").build()
+          )
+        checkTokenCountsMatch(response)
+        assert(response.promptTokensDetails.isEmpty())
+        assert(response.totalTokens == 0)
+      }
+    }
+  }
+
+  fun checkTokenCountsMatch(response: CountTokensResponse) {
+    assert(sumTokenCount(response) == response.totalTokens)
+  }
+
+  fun sumTokenCount(response: CountTokensResponse): Int {
+    return response.promptTokensDetails.sumOf { it.tokenCount }
+  }
+
+  fun containsModality(response: CountTokensResponse, modality: ContentModality): Boolean {
+    for (token in response.promptTokensDetails) {
+      if (token.modality == modality) {
+        return true
+      }
+    }
+    return false
+  }
+}
diff --git a/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/GenerateContentTests.kt b/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/GenerateContentTests.kt
new file mode 100644
index 00000000000..7de068311ef
--- /dev/null
+++ b/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/GenerateContentTests.kt
@@ -0,0 +1,83 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.firebase.ai
+
+import android.graphics.Bitmap
+import com.google.firebase.ai.AIModels.Companion.getModels
+import com.google.firebase.ai.type.Content
+import kotlinx.coroutines.runBlocking
+import org.junit.Test
+
+class GenerateContentTests {
+  private val validator = TypesValidator()
+
+  /**
+   * Ensures the model can respond to prompts and that the structure of the response is as
+   * expected.
+   */
+  @Test
+  fun testGenerateContent_BasicRequest() {
+    for (model in getModels()) {
+      runBlocking {
+        val response = model.generateContent("pick a random color")
+        validator.validateResponse(response)
+      }
+    }
+  }
+
+  /**
+   * Ensures that the model can answer very simple questions. Further testing the "logic" of the
+   * model and the content of the responses is prone to flaking; this test is also prone to that.
+   * This is probably the furthest we can consistently test for reasonable response structure, past
+   * sending the request and response back to the model and asking it if it fits our expectations.
+   */
+  @Test
+  fun testGenerateContent_ColorMixing() {
+    for (model in getModels()) {
+      runBlocking {
+        val response = model.generateContent("what color is created when red and yellow are mixed?")
+        validator.validateResponse(response)
+        assert(response.text!!.contains("orange", true))
+      }
+    }
+  }
+
+  /**
+   * Ensures that the model can accept image input alongside text. Further testing the "logic" of
+   * the model and the content of the responses is prone to flaking; this test is also prone to
+   * that. This is probably the furthest we can consistently test for reasonable response
+   * structure, past sending the request and response back to the model and asking it if it fits
+   * our expectations.
+   */
+  @Test
+  fun testGenerateContent_CanSendImage() {
+    for (model in getModels()) {
+      runBlocking {
+        val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
+        val yellow = Integer.parseUnsignedInt("FFFFFF00", 16)
+        bitmap.setPixel(3, 3, yellow)
+        bitmap.setPixel(6, 3, yellow)
+        bitmap.setPixel(3, 6, yellow)
+        bitmap.setPixel(4, 7, yellow)
+        bitmap.setPixel(5, 7, yellow)
+        bitmap.setPixel(6, 6, yellow)
+        val response =
+          model.generateContent(
+            Content.Builder().text("here is a tiny smile").image(bitmap).build()
+          )
+        validator.validateResponse(response)
+      }
+    }
+  }
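+
+  /**
+   * Illustrative sketch, not part of the original change: sends two separate Content objects
+   * (text and image) in one request, mirroring the multi-content style used in CountTokensTests,
+   * and only validates the response structure. The prompt text here is made up.
+   */
+  @Test
+  fun testGenerateContent_MultipleContents() {
+    for (model in getModels()) {
+      runBlocking {
+        val bitmap = Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
+        val response =
+          model.generateContent(
+            Content.Builder().text("describe this image in one word").build(),
+            Content.Builder().image(bitmap).build()
+          )
+        validator.validateResponse(response)
+      }
+    }
+  }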
+}
diff --git a/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ToolTests.kt b/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ToolTests.kt
new file mode 100644
index 00000000000..92ebbf4fc97
--- /dev/null
+++ b/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/ToolTests.kt
@@ -0,0 +1,312 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.firebase.ai
+
+import com.google.firebase.ai.AIModels.Companion.app
+import com.google.firebase.ai.type.Content
+import com.google.firebase.ai.type.FunctionCallingConfig
+import com.google.firebase.ai.type.FunctionDeclaration
+import com.google.firebase.ai.type.FunctionResponsePart
+import com.google.firebase.ai.type.GenerativeBackend
+import com.google.firebase.ai.type.Schema
+import com.google.firebase.ai.type.Tool
+import com.google.firebase.ai.type.ToolConfig
+import com.google.firebase.ai.type.content
+import io.ktor.util.toLowerCasePreservingASCIIRules
+import kotlinx.coroutines.runBlocking
+import kotlinx.serialization.json.JsonArray
+import kotlinx.serialization.json.JsonElement
+import kotlinx.serialization.json.JsonNull
+import kotlinx.serialization.json.JsonObject
+import kotlinx.serialization.json.JsonPrimitive
+import kotlinx.serialization.json.booleanOrNull
+import kotlinx.serialization.json.doubleOrNull
+import kotlinx.serialization.json.float
+import kotlinx.serialization.json.intOrNull
+import kotlinx.serialization.json.jsonArray
+import kotlinx.serialization.json.jsonObject
+import kotlinx.serialization.json.jsonPrimitive
+import org.junit.Test
+
+class ToolTests {
+  val validator = TypesValidator()
+
+  @Test
+  fun testTools_functionCallStructuring() {
+    val schema =
+      mapOf(
+        Pair(
+          "character",
+          Schema.obj(
+            mapOf(
+              Pair("name", Schema.string("the character's full name")),
+              Pair("gender", Schema.string("the character's gender")),
+              Pair("weight", Schema.float("the character's weight, in kilograms")),
+              Pair("height", Schema.float("the character's height, in centimeters")),
+              Pair(
+                "favorite_foods",
+                Schema.array(
+                  Schema.string("the name of a food"),
+                  "a short list of the character's favorite foods"
+                )
+              ),
+              Pair(
+                "mother",
+                Schema.obj(
+                  mapOf(Pair("name", Schema.string("the character's mother's name"))),
+                  description = "information about the character's mother"
+                )
+              ),
+            )
+          )
+        ),
+      )
+    val model =
+      setupModel(
+        FunctionDeclaration(
+          name = "getFavoriteColor",
+          description =
+            "determines a video game character's favorite color based on their features",
+          parameters = schema
+        ),
+      )
+    runBlocking {
+      val response =
+        model.generateContent(
+          "I'm imagining a video game character whose name is sam, but I can't think of the rest of their traits, could you make them up for me and figure out the character's favorite color?"
+        )
+      validator.validateResponse(response)
+      assert(response.functionCalls.size == 1)
+      val call = response.functionCalls[0]
+      assert(call.name == "getFavoriteColor")
+      validateSchema(schema, call.args)
+    }
+  }
+
+  @Test
+  fun testTools_basicDecisionMaking() {
+    val schema =
+      mapOf(
+        Pair("character", Schema.string("the character whose favorite color should be obtained"))
+      )
+    val model =
+      setupModel(
+        FunctionDeclaration(
+          name = "getFavoriteColor",
+          description = "returns the favorite color from a provided character's name",
+          parameters = schema
+        ),
+        FunctionDeclaration(
+          name = "eatAllSnacks",
+          description =
+            "orders a robot to find the kitchen of the provided character by their name, then eat all of their snacks so they get really sad. returns how many snacks were eaten",
+          parameters = schema
+        )
+      )
+    runBlocking {
+      val response = model.generateContent("what is amy's favorite color?")
+      validator.validateResponse(response)
+      assert(response.functionCalls.size == 1)
+      val call = response.functionCalls[0]
+      assert(call.name == "getFavoriteColor")
+      validateSchema(schema, call.args)
+    }
+  }
+
+  /** Ensures the model is capable of a simple question, tool call, response workflow. */
+  @Test
+  fun testTools_BasicToolCall() {
+    val schema =
+      mapOf(
+        Pair("character", Schema.string("the character whose favorite color should be obtained"))
+      )
+    val model =
+      setupModel(
+        FunctionDeclaration(
+          name = "getFavoriteColor",
+          description = "returns the favorite color from a provided character's name",
+          parameters = schema
+        )
+      )
+    runBlocking {
+      val question = content { text("what's bob's favorite color?") }
+      val response = model.generateContent(question)
+      validator.validateResponse(response)
+      assert(response.functionCalls.size == 1)
+      for (call in response.functionCalls) {
+        assert(call.name == "getFavoriteColor")
+        validateSchema(schema, call.args)
+        assert(
+          call.args["character"]!!.jsonPrimitive.content.toLowerCasePreservingASCIIRules() == "bob"
+        )
+        model.generateContent(
+          question,
+          Content(
+            role = "model",
+            parts =
+              listOf(
+                call,
+              )
+          ),
+          Content(
+            parts =
+              listOf(
+                FunctionResponsePart(
+                  id = call.id,
+                  name = call.name,
+                  response = JsonObject(mapOf(Pair("result", JsonPrimitive("green"))))
+                ),
+              )
+          )
+        )
+      }
+    }
+  }
+
+  /**
+   * Ensures the model can chain function calls together to reach trivial conclusions. In this case,
+   * the model needs to use the output of one function call as the input to another.
+   */
+  @Test
+  fun testTools_sequencingFunctionCalls() {
+    val nameSchema =
+      mapOf(
+        Pair("name", Schema.string("the name of the person whose birth month should be obtained"))
+      )
+    val monthSchema =
+      mapOf(Pair("month", Schema.string("the month whose color should be obtained")))
+    val model =
+      setupModel(
+        FunctionDeclaration(
+          name = "getBirthMonth",
+          description = "returns a person's birth month based on their name",
+          parameters = nameSchema
+        ),
+        FunctionDeclaration(
+          name = "getMonthColor",
+          description = "returns the color for a certain month",
+          parameters = monthSchema
+        )
+      )
+    runBlocking {
+      val question = content { text("what color is john's birth month") }
+      val response = model.generateContent(question)
+      assert(response.functionCalls.size == 1)
+      val call = response.functionCalls[0]
+      assert(call.name == "getBirthMonth")
+      assert(call.args["name"]!!.jsonPrimitive.content.toLowerCasePreservingASCIIRules() == "john")
+      validateSchema(nameSchema, call.args)
+      val response2 =
+        model.generateContent(
+          question,
+          Content(
+            role = "model",
+            parts =
+              listOf(
+                call,
+              )
+          ),
+          Content(
+            parts =
+              listOf(
+                FunctionResponsePart(
+                  id = call.id,
+                  name = call.name,
+                  response = JsonObject(mapOf(Pair("result", JsonPrimitive("june"))))
+                ),
+              )
+          )
+        )
+      validator.validateResponse(response)
+      assert(response2.functionCalls.size == 1)
+      val call2 = response2.functionCalls[0]
+      assert(call2.name == "getMonthColor")
+      assert(
+        call2.args["month"]!!.jsonPrimitive.content.toLowerCasePreservingASCIIRules() == "june"
+      )
+      validateSchema(monthSchema, call2.args)
+    }
+  }
+
+  fun validateSchema(schema: Map<String, Schema>, args: Map<String, JsonElement>) {
+    // Model should not call the function with unspecified arguments
+    assert(schema.keys.containsAll(args.keys))
+    for (entry in schema) {
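+      // Each argument the model supplied is checked against its declared parameter schema.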
+      validateSchema(entry.value, args.get(entry.key))
+    }
+  }
+
+  /** Simple schema validation. Not comprehensive, but should detect notable inaccuracy. */
+  fun validateSchema(schema: Schema, json: JsonElement?) {
+    if (json == null) {
+      assert(schema.nullable == true)
+      return
+    }
+    when (json) {
+      is JsonNull -> {
+        assert(schema.nullable == true)
+      }
+      is JsonPrimitive -> {
+        if (schema.type == "INTEGER") {
+          assert(json.intOrNull != null)
+        } else if (schema.type == "NUMBER") {
+          assert(json.doubleOrNull != null)
+        } else if (schema.type == "BOOLEAN") {
+          assert(json.booleanOrNull != null)
+        } else if (schema.type == "STRING") {
+          assert(json.isString)
+        } else {
+          assert(false)
+        }
+      }
+      is JsonObject -> {
+        assert(schema.type == "OBJECT")
+        val required = schema.required ?: listOf()
+        val obj = json.jsonObject
+        for (entry in schema.properties!!) {
+          if (obj.containsKey(entry.key)) {
+            validateSchema(entry.value, obj.get(entry.key))
+          } else {
+            assert(!required.contains(entry.key))
+          }
+        }
+      }
+      is JsonArray -> {
+        assert(schema.type == "ARRAY")
+        for (e in json.jsonArray) {
+          validateSchema(schema.items!!, e)
+        }
+      }
+    }
+  }
+
+  companion object {
+    @JvmStatic
+    fun setupModel(vararg functions: FunctionDeclaration): GenerativeModel {
+      val model =
+        FirebaseAI.getInstance(app(), GenerativeBackend.vertexAI())
+          .generativeModel(
+            modelName = "gemini-2.0-flash",
+            toolConfig =
+              ToolConfig(
+                functionCallingConfig = FunctionCallingConfig(FunctionCallingConfig.Mode.ANY)
+              ),
+            tools = listOf(Tool(functions.toList())),
+          )
+      return model
+    }
+  }
+}
diff --git a/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TypesValidator.kt b/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TypesValidator.kt
new file mode 100644
index 00000000000..768b9cf4eca
--- /dev/null
+++ b/firebase-ai/src/androidTest/kotlin/com/google/firebase/ai/TypesValidator.kt
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2025 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.firebase.ai
+
+import com.google.firebase.ai.type.Candidate
+import com.google.firebase.ai.type.Content
+import com.google.firebase.ai.type.GenerateContentResponse
+import com.google.firebase.ai.type.TextPart
+
+/** Performs structural validation of various API types */
+class TypesValidator {
+
+  fun validateResponse(response: GenerateContentResponse) {
+    if (response.candidates.isNotEmpty() && hasText(response.candidates[0].content)) {
+      assert(response.text!!.isNotEmpty())
+    } else if (response.candidates.isNotEmpty()) {
+      assert(!hasText(response.candidates[0].content))
+    }
+    response.candidates.forEach { validateCandidate(it) }
+  }
+
+  fun validateCandidate(candidate: Candidate) {
+    validateContent(candidate.content)
+  }
+
+  fun validateContent(content: Content) {
+    assert(content.role != "user")
+  }
+
+  fun hasText(content: Content): Boolean {
+    return content.parts.filterIsInstance<TextPart>().isNotEmpty()
+  }
+}
diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/CountTokensResponse.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/CountTokensResponse.kt
index 955f7bf941a..8909d15fc43 100644
--- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/CountTokensResponse.kt
+++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/CountTokensResponse.kt
@@ -46,14 +46,14 @@ public class CountTokensResponse(
 
   @Serializable
   internal data class Internal(
-    val totalTokens: Int,
+    val totalTokens: Int? = null,
     val totalBillableCharacters: Int? = null,
     val promptTokensDetails: List<ModalityTokenCount.Internal>? = null
   ) : Response {
 
     internal fun toPublic(): CountTokensResponse {
       return CountTokensResponse(
-        totalTokens,
+        totalTokens ?: 0,
         totalBillableCharacters ?: 0,
         promptTokensDetails?.map { it.toPublic() } ?: emptyList()
       )