Skip to content

Commit 5822579

Browse files
davidmotsonDavid Motsonashvili
andauthored
add safety scores to image generation (#7322)
hard coded to return image safety scores, as well as parse and provide these to the developer. --------- Co-authored-by: David Motsonashvili <[email protected]>
1 parent b986d2f commit 5822579

File tree

5 files changed

+31
-2
lines changed

5 files changed

+31
-2
lines changed

firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ internal constructor(
177177
GenerateImageRequest.ImagenParameters(
178178
sampleCount = generationConfig?.numberOfImages ?: 1,
179179
includeRaiReason = true,
180+
includeSafetyAttributes = true,
180181
addWatermark = generationConfig?.addWatermark,
181182
personGeneration = safetySettings?.personFilterLevel?.internalVal,
182183
negativePrompt = generationConfig?.negativePrompt,
@@ -206,6 +207,7 @@ internal constructor(
206207
GenerateImageRequest.ImagenParameters(
207208
sampleCount = generationConfig?.numberOfImages ?: 1,
208209
includeRaiReason = true,
210+
includeSafetyAttributes = true,
209211
addWatermark = generationConfig?.addWatermark,
210212
personGeneration = safetySettings?.personFilterLevel?.internalVal,
211213
negativePrompt = generationConfig?.negativePrompt,

firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ internal data class GenerateImageRequest(
9393
internal data class ImagenParameters(
9494
val sampleCount: Int,
9595
val includeRaiReason: Boolean,
96+
val includeSafetyAttributes: Boolean,
9697
val storageUri: String?,
9798
val negativePrompt: String?,
9899
val aspectRatio: String?,

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenGenerationResponse.kt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ internal constructor(public val images: List<T>, public val filteredReason: Stri
4242
internal fun toPublicInline() =
4343
ImagenGenerationResponse(
4444
images = predictions.filter { it.mimeType != null }.map { it.toPublicInline() },
45-
null,
45+
predictions.firstNotNullOfOrNull { it.raiFilteredReason },
4646
)
4747
}
4848

@@ -52,10 +52,24 @@ internal constructor(public val images: List<T>, public val filteredReason: Stri
5252
val gcsUri: String? = null,
5353
val mimeType: String? = null,
5454
val raiFilteredReason: String? = null,
55+
val safetyAttributes: ImagenSafetyAttributes? = null,
5556
) {
5657
internal fun toPublicInline() =
5758
ImagenInlineImage(Base64.decode(bytesBase64Encoded!!, Base64.NO_WRAP), mimeType!!)
5859

5960
internal fun toPublicGCS() = ImagenGCSImage(gcsUri!!, mimeType!!)
6061
}
62+
63+
@Serializable
64+
internal data class ImagenSafetyAttributes(
65+
val categories: List<String>? = null,
66+
val scores: List<Double>? = null
67+
) {
68+
internal fun toPublic(): Map<String, Double> {
69+
if (categories == null || scores == null) {
70+
return emptyMap()
71+
}
72+
return categories.zip(scores).toMap()
73+
}
74+
}
6175
}

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenInlineImage.kt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ import kotlinx.serialization.Serializable
3131
*/
3232
@PublicPreviewAPI
3333
public class ImagenInlineImage
34-
internal constructor(public val data: ByteArray, public val mimeType: String) {
34+
internal constructor(
35+
public val data: ByteArray,
36+
public val mimeType: String,
37+
) {
3538

3639
/**
3740
* Returns the image as an Android OS native [Bitmap] so that it can be saved or sent to the UI.

firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,15 @@ internal class VertexAIUnarySnapshotTests {
609609
}
610610
}
611611

612+
@Test
613+
fun `generateImages should contain safety data`() =
614+
goldenVertexUnaryFile("unary-success-generate-images-safety_info.json") {
615+
withTimeout(testTimeout) {
616+
val response = imagenModel.generateImages("prompt")
617+
// There is no public API, but if it parses then success
618+
}
619+
}
620+
612621
@Test
613622
fun `google search grounding metadata is parsed correctly`() =
614623
goldenVertexUnaryFile("unary-success-google-search-grounding.json") {

0 commit comments

Comments
 (0)