|
| 1 | +package com.google.firebase.vertexai |
| 2 | + |
| 3 | +import android.util.Log |
| 4 | +import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider |
| 5 | +import com.google.firebase.auth.internal.InternalAuthProvider |
| 6 | +import com.google.firebase.vertexai.common.APIController |
| 7 | +import com.google.firebase.vertexai.common.HeaderProvider |
| 8 | +import com.google.firebase.vertexai.internal.GenerateImageRequest |
| 9 | +import com.google.firebase.vertexai.internal.ImagenParameters |
| 10 | +import com.google.firebase.vertexai.internal.ImagenPromptInstance |
| 11 | +import com.google.firebase.vertexai.internal.util.toInternal |
| 12 | +import com.google.firebase.vertexai.internal.util.toPublicGCS |
| 13 | +import com.google.firebase.vertexai.internal.util.toPublicInline |
| 14 | +import com.google.firebase.vertexai.type.FirebaseVertexAIException |
| 15 | +import com.google.firebase.vertexai.type.ImageSafetySettings |
| 16 | +import com.google.firebase.vertexai.type.ImagenGCSImage |
| 17 | +import com.google.firebase.vertexai.type.ImagenGenerationConfig |
| 18 | +import com.google.firebase.vertexai.type.ImagenGenerationResponse |
| 19 | +import com.google.firebase.vertexai.type.ImagenImageRepresentible |
| 20 | +import com.google.firebase.vertexai.type.ImagenInlineImage |
| 21 | +import com.google.firebase.vertexai.type.ImagenModelConfig |
| 22 | +import com.google.firebase.vertexai.type.PromptBlockedException |
| 23 | +import com.google.firebase.vertexai.type.RequestOptions |
| 24 | +import kotlinx.coroutines.tasks.await |
| 25 | +import kotlin.time.Duration |
| 26 | +import kotlin.time.Duration.Companion.seconds |
| 27 | + |
| 28 | +public class ImageModel |
| 29 | +internal constructor( |
| 30 | + private val modelName: String, |
| 31 | + private val generationConfig: ImagenModelConfig? = null, |
| 32 | + private val safetySettings: ImageSafetySettings? = null, |
| 33 | + private val controller: APIController, |
| 34 | +) { |
| 35 | + @JvmOverloads |
| 36 | + internal constructor( |
| 37 | + modelName: String, |
| 38 | + apiKey: String, |
| 39 | + generationConfig: ImagenModelConfig? = null, |
| 40 | + safetySettings: ImageSafetySettings? = null, |
| 41 | + requestOptions: RequestOptions = RequestOptions(), |
| 42 | + appCheckTokenProvider: InteropAppCheckTokenProvider? = null, |
| 43 | + internalAuthProvider: InternalAuthProvider? = null, |
| 44 | + ) : this( |
| 45 | + modelName, |
| 46 | + generationConfig, |
| 47 | + safetySettings, |
| 48 | + APIController( |
| 49 | + apiKey, |
| 50 | + modelName, |
| 51 | + requestOptions, |
| 52 | + "gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}", |
| 53 | + object : HeaderProvider { |
| 54 | + override val timeout: Duration |
| 55 | + get() = 10.seconds |
| 56 | + |
| 57 | + override suspend fun generateHeaders(): Map<String, String> { |
| 58 | + val headers = mutableMapOf<String, String>() |
| 59 | + if (appCheckTokenProvider == null) { |
| 60 | + Log.w(TAG, "AppCheck not registered, skipping") |
| 61 | + } else { |
| 62 | + val token = appCheckTokenProvider.getToken(false).await() |
| 63 | + |
| 64 | + if (token.error != null) { |
| 65 | + Log.w(TAG, "Error obtaining AppCheck token", token.error) |
| 66 | + } |
| 67 | + // The Firebase App Check backend can differentiate between apps without App Check, and |
| 68 | + // wrongly configured apps by verifying the value of the token, so it always needs to be |
| 69 | + // included. |
| 70 | + headers["X-Firebase-AppCheck"] = token.token |
| 71 | + } |
| 72 | + |
| 73 | + if (internalAuthProvider == null) { |
| 74 | + Log.w(TAG, "Auth not registered, skipping") |
| 75 | + } else { |
| 76 | + try { |
| 77 | + val token = internalAuthProvider.getAccessToken(false).await() |
| 78 | + |
| 79 | + headers["Authorization"] = "Firebase ${token.token!!}" |
| 80 | + } catch (e: Exception) { |
| 81 | + Log.w(TAG, "Error getting Auth token ", e) |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + return headers |
| 86 | + } |
| 87 | + }, |
| 88 | + ), |
| 89 | + ) |
| 90 | + |
| 91 | + public suspend fun generateImage( |
| 92 | + prompt: String, |
| 93 | + config: ImagenGenerationConfig?, |
| 94 | + ): ImagenGenerationResponse<ImagenInlineImage> = |
| 95 | + try { |
| 96 | + controller.generateImage(constructRequest(prompt, null, config)).toPublicInline().validate() |
| 97 | + } catch (e: Throwable) { |
| 98 | + throw FirebaseVertexAIException.from(e) |
| 99 | + } |
| 100 | + |
| 101 | + public suspend fun generateImage( |
| 102 | + prompt: String, |
| 103 | + gcsUri: String, |
| 104 | + config: ImagenGenerationConfig?, |
| 105 | + ): ImagenGenerationResponse<ImagenGCSImage> = |
| 106 | + try { |
| 107 | + controller.generateImage(constructRequest(prompt, gcsUri, config)).toPublicGCS().validate() |
| 108 | + } catch (e: Throwable) { |
| 109 | + throw FirebaseVertexAIException.from(e) |
| 110 | + } |
| 111 | + |
| 112 | + private fun constructRequest( |
| 113 | + prompt: String, |
| 114 | + gcsUri: String?, |
| 115 | + config: ImagenGenerationConfig?, |
| 116 | + ): GenerateImageRequest { |
| 117 | + return GenerateImageRequest( |
| 118 | + listOf(ImagenPromptInstance(prompt)), |
| 119 | + ImagenParameters( |
| 120 | + sampleCount = config?.numberOfImages ?: 1, |
| 121 | + includeRaiReason = true, |
| 122 | + addWatermark = generationConfig?.addWatermark, |
| 123 | + personGeneration = safetySettings?.personFilterLevel?.internalVal, |
| 124 | + negativePrompt = config?.negativePrompt, |
| 125 | + safetySetting = safetySettings?.safetyFilterLevel?.internalVal, |
| 126 | + storageUri = gcsUri, |
| 127 | + aspectRatio = config?.aspectRatio?.internalVal, |
| 128 | + imageOutputOptions = generationConfig?.imageFormat?.toInternal(), |
| 129 | + ), |
| 130 | + ) |
| 131 | + } |
| 132 | + |
| 133 | + internal companion object { |
| 134 | + private val TAG = ImageModel::class.java.simpleName |
| 135 | + internal const val DEFAULT_FILTERED_ERROR = "Unable to show generated images. All images were filtered out because they violated Vertex AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback." |
| 136 | + } |
| 137 | +} |
| 138 | + |
| 139 | +private fun <T : ImagenImageRepresentible> ImagenGenerationResponse<T>.validate(): |
| 140 | + ImagenGenerationResponse<T> { |
| 141 | + if (images.isEmpty()) { |
| 142 | + throw PromptBlockedException(message = filteredReason ?: ImageModel.DEFAULT_FILTERED_ERROR) |
| 143 | + } |
| 144 | + return this |
| 145 | +} |
0 commit comments