Skip to content

Commit a9ea9e6

Browse files
author
David Motsonashvili
committed
add tests and minor adjustments to support them
1 parent 26b2516 commit a9ea9e6

File tree

18 files changed

+120
-65
lines changed

18 files changed

+120
-65
lines changed

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import com.google.firebase.vertexai.type.Content
2626
import com.google.firebase.vertexai.type.GenerationConfig
2727
import com.google.firebase.vertexai.type.ImagenGenerationConfig
2828
import com.google.firebase.vertexai.type.ImagenSafetySettings
29-
import com.google.firebase.vertexai.type.ImagenModelConfig
3029
import com.google.firebase.vertexai.type.InvalidLocationException
3130
import com.google.firebase.vertexai.type.RequestOptions
3231
import com.google.firebase.vertexai.type.SafetySetting

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/ImageModel.kt

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,24 @@ import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
55
import com.google.firebase.auth.internal.InternalAuthProvider
66
import com.google.firebase.vertexai.common.APIController
77
import com.google.firebase.vertexai.common.HeaderProvider
8+
import com.google.firebase.vertexai.common.PromptBlockedException
89
import com.google.firebase.vertexai.internal.GenerateImageRequest
10+
import com.google.firebase.vertexai.internal.GenerateImageResponse
911
import com.google.firebase.vertexai.internal.ImagenParameters
1012
import com.google.firebase.vertexai.internal.ImagenPromptInstance
1113
import com.google.firebase.vertexai.internal.util.toInternal
1214
import com.google.firebase.vertexai.internal.util.toPublicGCS
1315
import com.google.firebase.vertexai.internal.util.toPublicInline
1416
import com.google.firebase.vertexai.type.FirebaseVertexAIException
15-
import com.google.firebase.vertexai.type.ImagenSafetySettings
1617
import com.google.firebase.vertexai.type.ImagenGCSImage
1718
import com.google.firebase.vertexai.type.ImagenGenerationConfig
1819
import com.google.firebase.vertexai.type.ImagenGenerationResponse
1920
import com.google.firebase.vertexai.type.ImagenInlineImage
20-
import com.google.firebase.vertexai.type.ImagenModelConfig
21-
import com.google.firebase.vertexai.type.PromptBlockedException
21+
import com.google.firebase.vertexai.type.ImagenSafetySettings
2222
import com.google.firebase.vertexai.type.RequestOptions
23-
import kotlinx.coroutines.tasks.await
2423
import kotlin.time.Duration
2524
import kotlin.time.Duration.Companion.seconds
25+
import kotlinx.coroutines.tasks.await
2626

2727
public class ImageModel
2828
internal constructor(
@@ -87,11 +87,12 @@ internal constructor(
8787
),
8888
)
8989

90-
public suspend fun generateImage(
91-
prompt: String,
92-
): ImagenGenerationResponse<ImagenInlineImage> =
90+
public suspend fun generateImage(prompt: String): ImagenGenerationResponse<ImagenInlineImage> =
9391
try {
94-
controller.generateImage(constructRequest(prompt, null, generationConfig)).toPublicInline().validate()
92+
controller
93+
.generateImage(constructRequest(prompt, null, generationConfig))
94+
.validate()
95+
.toPublicInline()
9596
} catch (e: Throwable) {
9697
throw FirebaseVertexAIException.from(e)
9798
}
@@ -101,7 +102,10 @@ internal constructor(
101102
gcsUri: String,
102103
): ImagenGenerationResponse<ImagenGCSImage> =
103104
try {
104-
controller.generateImage(constructRequest(prompt, gcsUri, generationConfig)).toPublicGCS().validate()
105+
controller
106+
.generateImage(constructRequest(prompt, gcsUri, generationConfig))
107+
.validate()
108+
.toPublicGCS()
105109
} catch (e: Throwable) {
106110
throw FirebaseVertexAIException.from(e)
107111
}
@@ -129,14 +133,17 @@ internal constructor(
129133

130134
internal companion object {
131135
private val TAG = ImageModel::class.java.simpleName
132-
internal const val DEFAULT_FILTERED_ERROR = "Unable to show generated images. All images were filtered out because they violated Vertex AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback."
136+
internal const val DEFAULT_FILTERED_ERROR =
137+
"Unable to show generated images. All images were filtered out because they violated Vertex AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback."
133138
}
134139
}
135140

136-
private fun <T> ImagenGenerationResponse<T>.validate():
137-
ImagenGenerationResponse<T> {
138-
if (images.isEmpty()) {
139-
throw PromptBlockedException(message = filteredReason ?: ImageModel.DEFAULT_FILTERED_ERROR)
141+
private fun GenerateImageResponse.validate(): GenerateImageResponse {
142+
if (predictions.none { it.mimeType != null }) {
143+
throw PromptBlockedException(
144+
message = predictions.first { it.raiFilteredReason != null }.raiFilteredReason
145+
?: ImageModel.DEFAULT_FILTERED_ERROR
146+
)
140147
}
141148
return this
142149
}

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,9 @@ private suspend fun validateResponse(response: HttpResponse) {
271271
if (message.contains("quota")) {
272272
throw QuotaExceededException(message)
273273
}
274+
if (message.contains("The prompt could not be submitted")) {
275+
throw PromptBlockedException(message)
276+
}
274277
getServiceDisabledErrorDetailsOrNull(error)?.let {
275278
val errorMessage =
276279
if (it.metadata?.get("service") == "firebasevertexai.googleapis.com") {
@@ -309,6 +312,4 @@ private fun GenerateContentResponse.validate() = apply {
309312
?.let { throw ResponseStoppedException(this) }
310313
}
311314

312-
private fun GenerateImageResponse.validate() = apply {
313-
314-
}
315+
private fun GenerateImageResponse.validate() = apply {}

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Exceptions.kt

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,18 @@ internal class InvalidAPIKeyException(message: String, cause: Throwable? = null)
6565
*
6666
* @property response the full server response for the request.
6767
*/
68-
internal class PromptBlockedException(
69-
val response: GenerateContentResponse,
70-
cause: Throwable? = null
68+
internal class PromptBlockedException
69+
internal constructor(
70+
val response: GenerateContentResponse?,
71+
cause: Throwable? = null,
72+
message: String? = null,
7173
) :
7274
FirebaseCommonAIException(
73-
"Prompt was blocked: ${response.promptFeedback?.blockReason?.name}",
75+
"Prompt was blocked: ${response?.promptFeedback?.blockReason?.name?: message}",
7476
cause,
75-
)
77+
) {
78+
internal constructor(message: String, cause: Throwable? = null) : this(null, cause, message)
79+
}
7680

7781
/**
7882
* The user's location (region) is not supported by the API.

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/GenerateImageResponse.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ internal data class GenerateImageResponse(val predictions: List<ImagenImageRespo
99
internal data class ImagenImageResponse(
1010
val bytesBase64Encoded: String? = null,
1111
val gcsUri: String? = null,
12-
val mimeType: String,
12+
val mimeType: String? = null,
13+
val raiFilteredReason: String? = null,
1314
)

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/internal/util/conversions.kt

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -368,16 +368,22 @@ internal fun com.google.firebase.vertexai.common.CountTokensResponse.toPublic()
368368
CountTokensResponse(totalTokens, totalBillableCharacters ?: 0)
369369

370370
internal fun com.google.firebase.vertexai.internal.GenerateImageResponse.toPublicInline() =
371-
ImagenGenerationResponse(images = predictions.map { it.toPublicInline() }, null)
371+
ImagenGenerationResponse(
372+
images = predictions.filter { it.mimeType != null }.map { it.toPublicInline() },
373+
null,
374+
)
372375

373376
internal fun com.google.firebase.vertexai.internal.ImagenImageResponse.toPublicInline() =
374-
ImagenInlineImage(bytesBase64Encoded!!.toByteArray(), mimeType)
377+
ImagenInlineImage(bytesBase64Encoded!!.toByteArray(), mimeType!!)
375378

376379
internal fun com.google.firebase.vertexai.internal.GenerateImageResponse.toPublicGCS() =
377-
ImagenGenerationResponse(images = predictions.map { it.toPublicGCS() }, null)
380+
ImagenGenerationResponse(
381+
images = predictions.filter { it.mimeType != null }.map { it.toPublicGCS() },
382+
null,
383+
)
378384

379385
internal fun com.google.firebase.vertexai.internal.ImagenImageResponse.toPublicGCS() =
380-
ImagenGCSImage(gcsUri!!, mimeType)
386+
ImagenGCSImage(gcsUri!!, mimeType!!)
381387

382388
internal fun JsonObject.toPublic() = JSONObject(toString())
383389

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/Exceptions.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ internal constructor(message: String, cause: Throwable? = null) : RuntimeExcepti
4545
is com.google.firebase.vertexai.common.InvalidAPIKeyException ->
4646
InvalidAPIKeyException(cause.message ?: "")
4747
is com.google.firebase.vertexai.common.PromptBlockedException ->
48-
PromptBlockedException(cause.response.toPublic(), cause.cause)
48+
PromptBlockedException(cause.response?.toPublic(), cause.cause)
4949
is com.google.firebase.vertexai.common.UnsupportedUserLocationException ->
5050
UnsupportedUserLocationException(cause.cause)
5151
is com.google.firebase.vertexai.common.InvalidStateException ->

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/type/ImagenAspectRatio.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package com.google.firebase.vertexai.type
22

33
@Suppress("EnumEntryName")
44
public enum class ImagenAspectRatio(internal val internalVal: String) {
5-
SQUARE_1x1("1:1"),
6-
PORTRAIT_3x4("3:4"),
7-
LANDSCAPE_4x3("4:3"),
8-
PORTRAIT_9x16("9:16"),
9-
LANDSCAPE_16x9("16:9")
5+
SQUARE_1x1("1:1"),
6+
PORTRAIT_3x4("3:4"),
7+
LANDSCAPE_4x3("4:3"),
8+
PORTRAIT_9x16("9:16"),
9+
LANDSCAPE_16x9("16:9")
1010
}
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package com.google.firebase.vertexai.type
22

33
public class ImagenGenerationConfig(
4-
public val negativePrompt: String? = null,
5-
public val numberOfImages: Int? = 1,
6-
public val aspectRatio: ImagenAspectRatio? = null,
7-
public val imageFormat: ImagenImageFormat? = null,
8-
public val addWatermark: Boolean? = null
4+
public val negativePrompt: String? = null,
5+
public val numberOfImages: Int? = 1,
6+
public val aspectRatio: ImagenAspectRatio? = null,
7+
public val imageFormat: ImagenImageFormat? = null,
8+
public val addWatermark: Boolean? = null
99
) {}
Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
11
package com.google.firebase.vertexai.type
22

3-
public class ImagenModelConfig private constructor(
4-
5-
) {
6-
7-
}
3+
public class ImagenModelConfig private constructor() {}

0 commit comments

Comments
 (0)