Skip to content

Commit d23a093

Browse files
author
David Motsonashvili
committed
add safety scores to image generation
1 parent 546cbcb commit d23a093

File tree

4 files changed

+26
-4
lines changed

4 files changed

+26
-4
lines changed

firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ internal constructor(
177177
GenerateImageRequest.ImagenParameters(
178178
sampleCount = generationConfig?.numberOfImages ?: 1,
179179
includeRaiReason = true,
180+
includeSafetyAttributes = true,
180181
addWatermark = generationConfig?.addWatermark,
181182
personGeneration = safetySettings?.personFilterLevel?.internalVal,
182183
negativePrompt = generationConfig?.negativePrompt,
@@ -206,6 +207,7 @@ internal constructor(
206207
GenerateImageRequest.ImagenParameters(
207208
sampleCount = generationConfig?.numberOfImages ?: 1,
208209
includeRaiReason = true,
210+
includeSafetyAttributes = true,
209211
addWatermark = generationConfig?.addWatermark,
210212
personGeneration = safetySettings?.personFilterLevel?.internalVal,
211213
negativePrompt = generationConfig?.negativePrompt,

firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ internal data class GenerateImageRequest(
9393
internal data class ImagenParameters(
9494
val sampleCount: Int,
9595
val includeRaiReason: Boolean,
96+
val includeSafetyAttributes: Boolean,
9697
val storageUri: String?,
9798
val negativePrompt: String?,
9899
val aspectRatio: String?,

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenGenerationResponse.kt

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ internal constructor(public val images: List<T>, public val filteredReason: Stri
4242
internal fun toPublicInline() =
4343
ImagenGenerationResponse(
4444
images = predictions.filter { it.mimeType != null }.map { it.toPublicInline() },
45-
null,
45+
predictions.firstNotNullOfOrNull { it.raiFilteredReason },
4646
)
4747
}
4848

@@ -52,10 +52,25 @@ internal constructor(public val images: List<T>, public val filteredReason: Stri
5252
val gcsUri: String? = null,
5353
val mimeType: String? = null,
5454
val raiFilteredReason: String? = null,
55+
val safetyAttributes: ImagenSafetyAttributes? = null,
5556
) {
5657
internal fun toPublicInline() =
57-
ImagenInlineImage(Base64.decode(bytesBase64Encoded!!, Base64.NO_WRAP), mimeType!!)
58+
ImagenInlineImage(
59+
Base64.decode(bytesBase64Encoded!!, Base64.NO_WRAP),
60+
mimeType!!,
61+
safetyAttributes?.toPublic() ?: emptyMap()
62+
)
5863

5964
internal fun toPublicGCS() = ImagenGCSImage(gcsUri!!, mimeType!!)
6065
}
66+
67+
@Serializable
68+
internal data class ImagenSafetyAttributes(
69+
val categories: List<String>? = null,
70+
val scores: List<Double>? = null
71+
){
72+
internal fun toPublic(): Map<String, Double> {
73+
return categories?.zip(scores!!)?.toMap() ?: emptyMap()
74+
}
75+
}
6176
}

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenInlineImage.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,14 @@ import kotlinx.serialization.Serializable
2828
* @property data The raw image bytes in JPEG or PNG format, as specified by [mimeType].
2929
* @property mimeType The IANA standard MIME type of the image data; either `"image/png"` or
3030
* `"image/jpeg"`; to request a different format, see [ImagenGenerationConfig.imageFormat].
31+
* @property safetyAttributes a set of safety attributes with their associated score.
3132
*/
3233
@PublicPreviewAPI
3334
public class ImagenInlineImage
34-
internal constructor(public val data: ByteArray, public val mimeType: String) {
35+
internal constructor(
36+
public val data: ByteArray,
37+
public val mimeType: String,
38+
public val safetyAttributes: Map<String, Double>) {
3539

3640
/**
3741
* Returns the image as an Android OS native [Bitmap] so that it can be saved or sent to the UI.
@@ -53,5 +57,5 @@ public fun Bitmap.toImagenInlineImage(): ImagenInlineImage {
5357
val byteArrayOutputStream = ByteArrayOutputStream()
5458
this.compress(Bitmap.CompressFormat.JPEG, 80, byteArrayOutputStream)
5559
val byteArray = byteArrayOutputStream.toByteArray()
56-
return ImagenInlineImage(data = byteArray, mimeType = "image/jpeg")
60+
return ImagenInlineImage(data = byteArray, mimeType = "image/jpeg", safetyAttributes = emptyMap())
5761
}

0 commit comments

Comments
 (0)