Skip to content

Add imagen editing cases to the Firebase AI quickstart #2702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
package com.google.firebase.quickstart.ai

import android.graphics.Bitmap
import com.google.firebase.ai.ImagenModel
import com.google.firebase.ai.type.Dimensions
import com.google.firebase.ai.type.FunctionDeclaration
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.ImagenBackgroundMask
import com.google.firebase.ai.type.ImagenEditMode
import com.google.firebase.ai.type.ImagenEditingConfig
import com.google.firebase.ai.type.ImagenMaskReference
import com.google.firebase.ai.type.ImagenRawImage
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.ResponseModality
import com.google.firebase.ai.type.Schema
import com.google.firebase.ai.type.Tool
import com.google.firebase.ai.type.content
import com.google.firebase.ai.type.generationConfig
import com.google.firebase.ai.type.toImagenInlineImage
import com.google.firebase.quickstart.ai.ui.navigation.Category
import com.google.firebase.quickstart.ai.ui.navigation.Sample

@OptIn(PublicPreviewAPI::class)
val FIREBASE_AI_SAMPLES = listOf(
Sample(
title = "Travel tips",
Expand Down Expand Up @@ -133,6 +144,49 @@ val FIREBASE_AI_SAMPLES = listOf(
)
}
),
Sample(
title = "Imagen 3 - Inpainting",
description = "Replace the background of an image using Imagen 3",
modelName= "imagen-3.0-capability-001",
backend = GenerativeBackend.vertexAI(),
navRoute = "imagen",
categories = listOf(Category.IMAGE),
initialPrompt = content {
text(
"A sunny beach"
)
},
includeAttach = true,
generateImages = { model: ImagenModel, inputText: String, bitmap: Bitmap? ->
model.editImage(
listOf(ImagenRawImage(bitmap!!.toImagenInlineImage()), ImagenBackgroundMask()),
inputText,
ImagenEditingConfig(ImagenEditMode.INPAINT_INSERTION)
)
}
),
Sample(
title = "Imagen 3 - Outpainting",
description = "Expand an image by drawing in more background",
modelName= "imagen-3.0-capability-001",
backend = GenerativeBackend.vertexAI(),
navRoute = "imagen",
categories = listOf(Category.IMAGE),
initialPrompt = content {
text(
""
)
},
includeAttach = true,
generateImages = { model: ImagenModel, inputText: String, bitmap: Bitmap? ->
val dimensions = Dimensions(bitmap!!.width * 2, bitmap.height * 2)
model.editImage(
ImagenMaskReference.generateMaskAndPadForOutpainting(bitmap.toImagenInlineImage(), dimensions),
inputText,
ImagenEditingConfig(ImagenEditMode.OUTPAINT)
)
}
),
Sample(
title = "Gemini 2.0 Flash - image generation",
description = "Generate and/or edit images using Gemini 2.0 Flash",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class MainActivity : ComponentActivity() {
composable<ChatRoute> {
ChatScreen()
}
// Imagen Samples
// Imagn Samples
composable<ImagenRoute> {
ImagenScreen()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package com.google.firebase.quickstart.ai.feature.media.imagen

import android.net.Uri
import android.provider.OpenableColumns
import android.text.format.Formatter
import androidx.activity.compose.rememberLauncherForActivityResult
import androidx.activity.result.contract.ActivityResultContracts
import androidx.compose.foundation.Image
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
Expand All @@ -24,6 +29,7 @@ import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.asImageBitmap
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.unit.dp
import androidx.lifecycle.compose.collectAsStateWithLifecycle
import androidx.lifecycle.viewmodel.compose.viewModel
Expand All @@ -40,6 +46,29 @@ fun ImagenScreen(
val errorMessage by imagenViewModel.errorMessage.collectAsStateWithLifecycle()
val isLoading by imagenViewModel.isLoading.collectAsStateWithLifecycle()
val generatedImages by imagenViewModel.generatedBitmaps.collectAsStateWithLifecycle()
val includeAttach by imagenViewModel.includeAttach.collectAsStateWithLifecycle()
val context = LocalContext.current
val contentResolver = context.contentResolver
val openDocument = rememberLauncherForActivityResult(ActivityResultContracts.OpenDocument()) { optionalUri: Uri? ->
optionalUri?.let { uri ->
var fileName: String? = null
// Fetch file name and size
contentResolver.query(uri, null, null, null, null)?.use { cursor ->
val nameIndex = cursor.getColumnIndex(OpenableColumns.DISPLAY_NAME)
val sizeIndex = cursor.getColumnIndex(OpenableColumns.SIZE)
cursor.moveToFirst()
val humanReadableSize = Formatter.formatShortFileSize(
context, cursor.getLong(sizeIndex)
)
fileName = "${cursor.getString(nameIndex)} ($humanReadableSize)"
}

contentResolver.openInputStream(uri)?.use { stream ->
val bytes = stream.readBytes()
imagenViewModel.attachImage(bytes, fileName)
}
}
}

Column(
modifier = Modifier
Expand All @@ -59,11 +88,21 @@ fun ImagenScreen(
.padding(16.dp)
.fillMaxWidth()
)
if (includeAttach) {
TextButton(
onClick = {
openDocument.launch(arrayOf("image/*"))
},
modifier = Modifier
.padding(end = 16.dp, bottom = 16.dp)
.align(Alignment.End)


) { Text("Attach") }
}
TextButton(
onClick = {
if (imagenPrompt.isNotBlank()) {
imagenViewModel.generateImages(imagenPrompt)
}
imagenViewModel.generateImages(imagenPrompt)
},
modifier = Modifier
.padding(end = 16.dp, bottom = 16.dp)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package com.google.firebase.quickstart.ai.feature.media.imagen

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import androidx.lifecycle.SavedStateHandle
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import androidx.navigation.toRoute
import com.google.firebase.Firebase
import com.google.firebase.ai.ImagenModel
import com.google.firebase.ai.ai
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.ImagenAspectRatio
import com.google.firebase.ai.type.ImagenEditMode
import com.google.firebase.ai.type.ImagenEditingConfig
import com.google.firebase.ai.type.ImagenImageFormat
import com.google.firebase.ai.type.ImagenPersonFilterLevel
import com.google.firebase.ai.type.ImagenSafetyFilterLevel
Expand All @@ -36,38 +38,47 @@ class ImagenViewModel(
private val _isLoading = MutableStateFlow(false)
val isLoading: StateFlow<Boolean> = _isLoading

private val _includeAttach = MutableStateFlow(sample.includeAttach)
val includeAttach: StateFlow<Boolean> = _includeAttach

private val _generatedBitmaps = MutableStateFlow(listOf<Bitmap>())
val generatedBitmaps: StateFlow<List<Bitmap>> = _generatedBitmaps

// Firebase AI Logic
private val imagenModel: ImagenModel
private var attachedImage: Bitmap?

init {
val config = imagenGenerationConfig {
numberOfImages = 4
aspectRatio = ImagenAspectRatio.SQUARE_1x1
imageFormat = ImagenImageFormat.png()
}
val settings = ImagenSafetySettings(
safetyFilterLevel = ImagenSafetyFilterLevel.BLOCK_LOW_AND_ABOVE,
personFilterLevel = ImagenPersonFilterLevel.BLOCK_ALL
)
imagenModel = Firebase.ai(
backend = GenerativeBackend.googleAI()
backend = sample.backend
).imagenModel(
modelName = sample.modelName ?: "imagen-3.0-generate-002",
generationConfig = config,
safetySettings = settings
)
attachedImage = null
}

fun generateImages(inputText: String) {
viewModelScope.launch {
_isLoading.value = true
try {
val imageResponse = imagenModel.generateImages(
inputText
)
val generateImages = sample.generateImages
val imageResponse = if (generateImages == null) {
imagenModel.generateImages(
inputText
)
} else {
generateImages(imagenModel, inputText, attachedImage)
}
_generatedBitmaps.value = imageResponse.images.map { it.asBitmap() }
_errorMessage.value = null // clear error message
} catch (e: Exception) {
Expand All @@ -77,4 +88,11 @@ class ImagenViewModel(
}
}
}

fun attachImage(
fileInBytes: ByteArray,
fileName: String? = "Unnamed file"
) {
attachedImage = BitmapFactory.decodeByteArray(fileInBytes, 0, fileInBytes.size)
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
package com.google.firebase.quickstart.ai.ui.navigation

import android.graphics.Bitmap
import com.google.firebase.ai.ImagenModel
import com.google.firebase.ai.type.Content
import com.google.firebase.ai.type.GenerationConfig
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.ImagenGenerationResponse
import com.google.firebase.ai.type.ImagenInlineImage
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.Tool
import java.util.UUID

Expand All @@ -17,6 +22,7 @@ enum class Category(
FUNCTION_CALLING("Function calling"),
}

@OptIn(PublicPreviewAPI::class)
data class Sample(
val id: String = UUID.randomUUID().toString(), // used for navigation
val title: String,
Expand All @@ -30,5 +36,7 @@ data class Sample(
val systemInstructions: Content? = null,
val generationConfig: GenerationConfig? = null,
val chatHistory: List<Content> = emptyList(),
val tools: List<Tool>? = null
val tools: List<Tool>? = null,
val includeAttach: Boolean = false,
val generateImages: (suspend (ImagenModel, String, Bitmap?) -> ImagenGenerationResponse<ImagenInlineImage>)? = null
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than passing a suspend function here, what if we pass the List<ImagenReferenceImage> and prompt?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That may not work as an abstraction, sometimes editing settings also need to be set, basically every parameter, depending on the use case.

)
Loading