Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 77 additions & 9 deletions firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.google.firebase.ai.common.AppCheckHeaderProvider
import com.google.firebase.ai.common.ContentBlockedException
import com.google.firebase.ai.common.GenerateImageRequest
import com.google.firebase.ai.type.FirebaseAIException
import com.google.firebase.ai.type.ImagenEditingConfig
import com.google.firebase.ai.type.ImagenGenerationConfig
import com.google.firebase.ai.type.ImagenGenerationResponse
import com.google.firebase.ai.type.ImagenInlineImage
Expand Down Expand Up @@ -75,30 +76,97 @@ internal constructor(
public suspend fun generateImages(prompt: String): ImagenGenerationResponse<ImagenInlineImage> =
try {
controller
.generateImage(constructRequest(prompt, null, generationConfig))
.generateImage(constructGenerationRequest(prompt, null, generationConfig))
.validate()
.toPublicInline()
} catch (e: Throwable) {
throw FirebaseAIException.from(e)
}

private fun constructRequest(
public suspend fun editImage(
prompt: String,
gcsUri: String?,
config: ImagenGenerationConfig?,
config: ImagenEditingConfig
): ImagenGenerationResponse<ImagenInlineImage> =
try {
controller
.generateImage(constructEditRequest(prompt, null, config))
.validate()
.toPublicInline()
} catch (e: Throwable) {
throw FirebaseAIException.from(e)
}

private fun constructGenerationRequest(
prompt: String,
gcsUri: String? = null,
generationConfig: ImagenGenerationConfig? = null,
): GenerateImageRequest {
return GenerateImageRequest(
listOf(GenerateImageRequest.ImagenPrompt(prompt)),
GenerateImageRequest.ImagenParameters(
sampleCount = config?.numberOfImages ?: 1,
sampleCount = generationConfig?.numberOfImages ?: 1,
includeRaiReason = true,
addWatermark = this.generationConfig?.addWatermark,
personGeneration = safetySettings?.personFilterLevel?.internalVal,
negativePrompt = generationConfig?.negativePrompt,
safetySetting = safetySettings?.safetyFilterLevel?.internalVal,
storageUri = gcsUri,
aspectRatio = generationConfig?.aspectRatio?.internalVal,
imageOutputOptions = this.generationConfig?.imageFormat?.toInternal(),
editMode = null,
editConfig = null
),
)
}

private fun constructEditRequest(
prompt: String,
gcsUri: String? = null,
editConfig: ImagenEditingConfig,
): GenerateImageRequest {
return GenerateImageRequest(
listOf(
GenerateImageRequest.ImagenPrompt(
prompt = prompt,
referenceImages =
buildList {
add(
GenerateImageRequest.ReferenceImage(
referenceType = GenerateImageRequest.ReferenceType.RAW,
referenceId = 1,
referenceImage = editConfig.image.toInternal(),
maskImageConfig = null
)
)
if (editConfig.mask != null) {
add(
GenerateImageRequest.ReferenceImage(
referenceType = GenerateImageRequest.ReferenceType.MASK,
referenceId = 2,
referenceImage = editConfig.mask.toInternal(),
maskImageConfig =
GenerateImageRequest.MaskImageConfig(
maskMode = GenerateImageRequest.MaskMode.USER_PROVIDED,
dilation = editConfig.maskDilation
)
)
)
}
}
)
),
GenerateImageRequest.ImagenParameters(
sampleCount = generationConfig?.numberOfImages ?: 1,
includeRaiReason = true,
addWatermark = generationConfig?.addWatermark,
addWatermark = this.generationConfig?.addWatermark,
personGeneration = safetySettings?.personFilterLevel?.internalVal,
negativePrompt = config?.negativePrompt,
negativePrompt = generationConfig?.negativePrompt,
safetySetting = safetySettings?.safetyFilterLevel?.internalVal,
storageUri = gcsUri,
aspectRatio = config?.aspectRatio?.internalVal,
imageOutputOptions = generationConfig?.imageFormat?.toInternal(),
aspectRatio = generationConfig?.aspectRatio?.internalVal,
imageOutputOptions = this.generationConfig?.imageFormat?.toInternal(),
editMode = editConfig.editMode.value,
editConfig = editConfig.toInternal()
),
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import com.google.firebase.ai.common.util.fullModelName
import com.google.firebase.ai.common.util.trimmedModelName
import com.google.firebase.ai.type.Content
import com.google.firebase.ai.type.GenerationConfig
import com.google.firebase.ai.type.ImagenEditingConfig
import com.google.firebase.ai.type.ImagenImageFormat
import com.google.firebase.ai.type.ImagenInlineImage
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.SafetySetting
import com.google.firebase.ai.type.Tool
Expand Down Expand Up @@ -75,23 +77,62 @@ internal data class CountTokensRequest(
}

@Serializable
@PublicPreviewAPI
internal data class GenerateImageRequest(
val instances: List<ImagenPrompt>,
val parameters: ImagenParameters,
) : Request {
@Serializable internal data class ImagenPrompt(val prompt: String)
@Serializable
internal data class ImagenPrompt(
val prompt: String? = null,
val image: ImagenInlineImage.Internal? = null,
val referenceImages: List<ReferenceImage>? = null
)

@OptIn(PublicPreviewAPI::class)
@Serializable
internal data class ImagenParameters(
val sampleCount: Int,
val includeRaiReason: Boolean,
val includeRaiReason: Boolean?,
val storageUri: String?,
val negativePrompt: String?,
val aspectRatio: String?,
val safetySetting: String?,
val personGeneration: String?,
val addWatermark: Boolean?,
val imageOutputOptions: ImagenImageFormat.Internal?,
val editMode: String?,
val editConfig: ImagenEditingConfig.Internal?,
)

@Serializable
internal enum class ReferenceType {
@SerialName("REFERENCE_TYPE_UNSPECIFIED") UNSPECIFIED,
@SerialName("REFERENCE_TYPE_RAW") RAW,
@SerialName("REFERENCE_TYPE_MASK") MASK,
@SerialName("REFERENCE_TYPE_CONTROL") CONTROL,
@SerialName("REFERENCE_TYPE_STYLE") STYLE,
@SerialName("REFERENCE_TYPE_SUBJECT") SUBJECT,
@SerialName("REFERENCE_TYPE_MASKED_SUBJECT") MASKED_SUBJECT,
@SerialName("REFERENCE_TYPE_PRODUCT") PRODUCT
}

@Serializable
internal enum class MaskMode {
@SerialName("MASK_MODE_DEFAULT") DEFAULT,
@SerialName("MASK_MODE_USER_PROVIDED") USER_PROVIDED,
@SerialName("MASK_MODE_BACKGROUND") BACKGROUND,
@SerialName("MASK_MODE_FOREGROUND") FOREGROUND,
@SerialName("MASK_MODE_SEMANTIC") SEMANTIC
}

@Serializable internal data class MaskImageConfig(val maskMode: MaskMode, val dilation: Double?)

@Serializable
internal data class ReferenceImage(
val referenceType: ReferenceType,
val referenceId: Int,
val referenceImage: ImagenInlineImage.Internal,
val maskImageConfig: MaskImageConfig?
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.google.firebase.ai.type

public class ImagenEditMode private constructor(internal val value: String) {

public companion object {
public val INPAINT_INSERTION: ImagenEditMode = ImagenEditMode("EDIT_MODE_INPAINT_INSERTION")
public val INPAINT_REMOVAL: ImagenEditMode = ImagenEditMode("EDIT_MODE_INPAINT_REMOVAL")
public val OUTPAINT: ImagenEditMode = ImagenEditMode("EDIT_MODE_OUTPAINT")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package com.google.firebase.ai.type

import kotlinx.serialization.Serializable

@PublicPreviewAPI
public class ImagenEditingConfig(
public val image: ImagenInlineImage,
public val editMode: ImagenEditMode,
public val mask: ImagenInlineImage? = null,
public val maskDilation: Double? = null,
public val editSteps: Int? = null,
) {
public companion object {
public fun builder(): Builder = Builder()
}

public class Builder {
@JvmField public var image: ImagenInlineImage? = null
@JvmField public var editMode: ImagenEditMode? = null
@JvmField public var mask: ImagenInlineImage? = null
@JvmField public var maskDilation: Double? = null
@JvmField public var editSteps: Int? = null

public fun setImage(image: ImagenInlineImage): Builder = apply { this.image = image }

public fun setEditMode(editMode: ImagenEditMode): Builder = apply { this.editMode = editMode }

public fun setMask(mask: ImagenInlineImage): Builder = apply { this.mask = mask }

public fun setMaskDilation(maskDilation: Double): Builder = apply {
this.maskDilation = maskDilation
}

public fun setEditSteps(editSteps: Int): Builder = apply { this.editSteps = editSteps }

public fun build(): ImagenEditingConfig {
if (image == null) {
throw IllegalStateException("ImagenEditingConfig must contain an image")
}
if (editMode == null) {
throw IllegalStateException("ImagenEditingConfig must contain an editMode")
}
return ImagenEditingConfig(
image = image!!,
editMode = editMode!!,
mask = mask,
maskDilation = maskDilation,
editSteps = editSteps,
)
}
}

internal fun toInternal(): Internal {
return Internal(baseSteps = editSteps)
}

@Serializable
internal data class Internal(
val baseSteps: Int?,
)
}

@PublicPreviewAPI
public fun imagenEditingConfig(init: ImagenEditingConfig.Builder.() -> Unit): ImagenEditingConfig {
val builder = ImagenEditingConfig.builder()
builder.init()
return builder.build()
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ package com.google.firebase.ai.type

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.util.Base64
import java.io.ByteArrayOutputStream
import kotlinx.serialization.Serializable

/**
* Represents an Imagen-generated image that is returned as inline data.
Expand All @@ -36,4 +39,19 @@ internal constructor(public val data: ByteArray, public val mimeType: String) {
public fun asBitmap(): Bitmap {
return BitmapFactory.decodeByteArray(data, 0, data.size)
}

@Serializable internal data class Internal(val bytesBase64Encoded: String)

internal fun toInternal(): Internal {
val base64 = Base64.encodeToString(data, Base64.NO_WRAP)
return Internal(base64)
}
}

@PublicPreviewAPI
public fun Bitmap.toImagenImage(): ImagenInlineImage {
val byteArrayOutputStream = ByteArrayOutputStream()
this.compress(Bitmap.CompressFormat.PNG, 100, byteArrayOutputStream)
val byteArray = byteArrayOutputStream.toByteArray()
return ImagenInlineImage(data = byteArray, mimeType = "image/png")
}
Loading