Skip to content

Commit f826190

Browse files
author
David Motsonashvili
committed
Initial implementation to API spec
1 parent 11403df commit f826190

21 files changed

+380
-15
lines changed

firebase-vertexai/gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
version=16.1.0
15+
version=17.0.0
1616
latestReleasedVersion=16.0.2

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/FirebaseVertexAI.kt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import com.google.firebase.auth.internal.InternalAuthProvider
2424
import com.google.firebase.inject.Provider
2525
import com.google.firebase.vertexai.type.Content
2626
import com.google.firebase.vertexai.type.GenerationConfig
27+
import com.google.firebase.vertexai.type.ImageSafetySettings
28+
import com.google.firebase.vertexai.type.ImagenModelConfig
2729
import com.google.firebase.vertexai.type.InvalidLocationException
2830
import com.google.firebase.vertexai.type.RequestOptions
2931
import com.google.firebase.vertexai.type.SafetySetting
@@ -79,6 +81,27 @@ internal constructor(
7981
)
8082
}
8183

84+
@JvmOverloads
85+
public fun imageModel(
86+
modelName: String,
87+
generationConfig: ImagenModelConfig? = null,
88+
safetySettings: ImageSafetySettings? = null,
89+
requestOptions: RequestOptions = RequestOptions(),
90+
): ImageModel {
91+
if (location.trim().isEmpty() || location.contains("/")) {
92+
throw InvalidLocationException(location)
93+
}
94+
return ImageModel(
95+
"projects/${firebaseApp.options.projectId}/locations/${location}/publishers/google/models/${modelName}",
96+
firebaseApp.options.apiKey,
97+
generationConfig,
98+
safetySettings,
99+
requestOptions,
100+
appCheckProvider.get(),
101+
internalAuthProvider.get(),
102+
)
103+
}
104+
82105
public companion object {
83106
/** The [FirebaseVertexAI] instance for the default [FirebaseApp] */
84107
@JvmStatic
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
package com.google.firebase.vertexai
2+
3+
import android.util.Log
4+
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
5+
import com.google.firebase.auth.internal.InternalAuthProvider
6+
import com.google.firebase.vertexai.common.APIController
7+
import com.google.firebase.vertexai.common.HeaderProvider
8+
import com.google.firebase.vertexai.internal.GenerateImageRequest
9+
import com.google.firebase.vertexai.internal.ImagenParameters
10+
import com.google.firebase.vertexai.internal.ImagenPromptInstance
11+
import com.google.firebase.vertexai.internal.util.toInternal
12+
import com.google.firebase.vertexai.internal.util.toPublicGCS
13+
import com.google.firebase.vertexai.internal.util.toPublicInline
14+
import com.google.firebase.vertexai.type.FirebaseVertexAIException
15+
import com.google.firebase.vertexai.type.ImageSafetySettings
16+
import com.google.firebase.vertexai.type.ImagenGCSImage
17+
import com.google.firebase.vertexai.type.ImagenGenerationConfig
18+
import com.google.firebase.vertexai.type.ImagenGenerationResponse
19+
import com.google.firebase.vertexai.type.ImagenImageRepresentible
20+
import com.google.firebase.vertexai.type.ImagenInlineImage
21+
import com.google.firebase.vertexai.type.ImagenModelConfig
22+
import com.google.firebase.vertexai.type.PromptBlockedException
23+
import com.google.firebase.vertexai.type.RequestOptions
24+
import kotlinx.coroutines.tasks.await
25+
import kotlin.time.Duration
26+
import kotlin.time.Duration.Companion.seconds
27+
28+
public class ImageModel
29+
internal constructor(
30+
private val modelName: String,
31+
private val generationConfig: ImagenModelConfig? = null,
32+
private val safetySettings: ImageSafetySettings? = null,
33+
private val controller: APIController,
34+
) {
35+
@JvmOverloads
36+
internal constructor(
37+
modelName: String,
38+
apiKey: String,
39+
generationConfig: ImagenModelConfig? = null,
40+
safetySettings: ImageSafetySettings? = null,
41+
requestOptions: RequestOptions = RequestOptions(),
42+
appCheckTokenProvider: InteropAppCheckTokenProvider? = null,
43+
internalAuthProvider: InternalAuthProvider? = null,
44+
) : this(
45+
modelName,
46+
generationConfig,
47+
safetySettings,
48+
APIController(
49+
apiKey,
50+
modelName,
51+
requestOptions,
52+
"gl-kotlin/${KotlinVersion.CURRENT} fire/${BuildConfig.VERSION_NAME}",
53+
object : HeaderProvider {
54+
override val timeout: Duration
55+
get() = 10.seconds
56+
57+
override suspend fun generateHeaders(): Map<String, String> {
58+
val headers = mutableMapOf<String, String>()
59+
if (appCheckTokenProvider == null) {
60+
Log.w(TAG, "AppCheck not registered, skipping")
61+
} else {
62+
val token = appCheckTokenProvider.getToken(false).await()
63+
64+
if (token.error != null) {
65+
Log.w(TAG, "Error obtaining AppCheck token", token.error)
66+
}
67+
// The Firebase App Check backend can differentiate between apps without App Check, and
68+
// wrongly configured apps by verifying the value of the token, so it always needs to be
69+
// included.
70+
headers["X-Firebase-AppCheck"] = token.token
71+
}
72+
73+
if (internalAuthProvider == null) {
74+
Log.w(TAG, "Auth not registered, skipping")
75+
} else {
76+
try {
77+
val token = internalAuthProvider.getAccessToken(false).await()
78+
79+
headers["Authorization"] = "Firebase ${token.token!!}"
80+
} catch (e: Exception) {
81+
Log.w(TAG, "Error getting Auth token ", e)
82+
}
83+
}
84+
85+
return headers
86+
}
87+
},
88+
),
89+
)
90+
91+
public suspend fun generateImage(
92+
prompt: String,
93+
config: ImagenGenerationConfig?,
94+
): ImagenGenerationResponse<ImagenInlineImage> =
95+
try {
96+
controller.generateImage(constructRequest(prompt, null, config)).toPublicInline().validate()
97+
} catch (e: Throwable) {
98+
throw FirebaseVertexAIException.from(e)
99+
}
100+
101+
public suspend fun generateImage(
102+
prompt: String,
103+
gcsUri: String,
104+
config: ImagenGenerationConfig?,
105+
): ImagenGenerationResponse<ImagenGCSImage> =
106+
try {
107+
controller.generateImage(constructRequest(prompt, gcsUri, config)).toPublicGCS().validate()
108+
} catch (e: Throwable) {
109+
throw FirebaseVertexAIException.from(e)
110+
}
111+
112+
private fun constructRequest(
113+
prompt: String,
114+
gcsUri: String?,
115+
config: ImagenGenerationConfig?,
116+
): GenerateImageRequest {
117+
return GenerateImageRequest(
118+
listOf(ImagenPromptInstance(prompt)),
119+
ImagenParameters(
120+
sampleCount = config?.numberOfImages ?: 1,
121+
includeRaiReason = true,
122+
addWatermark = generationConfig?.addWatermark,
123+
personGeneration = safetySettings?.personFilterLevel?.internalVal,
124+
negativePrompt = config?.negativePrompt,
125+
safetySetting = safetySettings?.safetyFilterLevel?.internalVal,
126+
storageUri = gcsUri,
127+
aspectRatio = config?.aspectRatio?.internalVal,
128+
imageOutputOptions = generationConfig?.imageFormat?.toInternal(),
129+
),
130+
)
131+
}
132+
133+
internal companion object {
134+
private val TAG = ImageModel::class.java.simpleName
135+
internal const val DEFAULT_FILTERED_ERROR = "Unable to show generated images. All images were filtered out because they violated Vertex AI's usage guidelines. You will not be charged for blocked images. Try rephrasing the prompt. If you think this was an error, send feedback."
136+
}
137+
}
138+
139+
private fun <T : ImagenImageRepresentible> ImagenGenerationResponse<T>.validate():
140+
ImagenGenerationResponse<T> {
141+
if (images.isEmpty()) {
142+
throw PromptBlockedException(message = filteredReason ?: ImageModel.DEFAULT_FILTERED_ERROR)
143+
}
144+
return this
145+
}

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/APIController.kt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import com.google.firebase.vertexai.common.server.GRpcError
2424
import com.google.firebase.vertexai.common.server.GRpcErrorDetails
2525
import com.google.firebase.vertexai.common.util.decodeToFlow
2626
import com.google.firebase.vertexai.common.util.fullModelName
27+
import com.google.firebase.vertexai.internal.GenerateImageRequest
28+
import com.google.firebase.vertexai.internal.GenerateImageResponse
2729
import com.google.firebase.vertexai.type.RequestOptions
2830
import io.ktor.client.HttpClient
2931
import io.ktor.client.call.body
@@ -120,6 +122,20 @@ internal constructor(
120122
throw FirebaseCommonAIException.from(e)
121123
}
122124

125+
suspend fun generateImage(request: GenerateImageRequest): GenerateImageResponse =
126+
try {
127+
client
128+
.post("${requestOptions.endpoint}/${requestOptions.apiVersion}/$model:predict") {
129+
applyCommonConfiguration(request)
130+
applyHeaderProvider()
131+
}
132+
.also { validateResponse(it) }
133+
.body<GenerateImageResponse>()
134+
.validate()
135+
} catch (e: Throwable) {
136+
throw FirebaseCommonAIException.from(e)
137+
}
138+
123139
fun generateContentStream(request: GenerateContentRequest): Flow<GenerateContentResponse> =
124140
client
125141
.postStream<GenerateContentResponse>(
@@ -147,6 +163,7 @@ internal constructor(
147163
when (request) {
148164
is GenerateContentRequest -> setBody<GenerateContentRequest>(request)
149165
is CountTokensRequest -> setBody<CountTokensRequest>(request)
166+
is GenerateImageRequest -> setBody<GenerateImageRequest>(request)
150167
}
151168
contentType(ContentType.Application.Json)
152169
header("x-goog-api-key", key)
@@ -291,3 +308,7 @@ private fun GenerateContentResponse.validate() = apply {
291308
?.firstOrNull { it != FinishReason.STOP }
292309
?.let { throw ResponseStoppedException(this) }
293310
}
311+
312+
private fun GenerateImageResponse.validate() = apply {
313+
314+
}

firebase-vertexai/src/main/kotlin/com/google/firebase/vertexai/common/Request.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import com.google.firebase.vertexai.common.util.fullModelName
2525
import kotlinx.serialization.SerialName
2626
import kotlinx.serialization.Serializable
2727

28-
internal sealed interface Request
28+
internal interface Request
2929

3030
@Serializable
3131
internal data class GenerateContentRequest(
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package com.google.firebase.vertexai.internal
2+
3+
import com.google.firebase.vertexai.common.Request
4+
import kotlinx.serialization.Serializable
5+
6+
@Serializable
7+
internal data class GenerateImageRequest(
8+
val instances: List<ImagenPromptInstance>,
9+
val parameters: ImagenParameters,
10+
) : Request {}
11+
12+
@Serializable internal data class ImagenPromptInstance(val prompt: String)
13+
14+
@Serializable
15+
internal data class ImagenParameters(
16+
val sampleCount: Int = 1,
17+
val includeRaiReason: Boolean = true,
18+
val storageUri: String?,
19+
val negativePrompt: String?,
20+
val aspectRatio: String?,
21+
val safetySetting: String?,
22+
val personGeneration: String?,
23+
val addWatermark: Boolean?,
24+
val imageOutputOptions: ImageOutputOptions?,
25+
)
26+
27+
@Serializable
28+
internal data class ImageOutputOptions(val mimeType: String, val compressionQuality: Int?)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package com.google.firebase.vertexai.internal
2+
3+
import kotlinx.serialization.Serializable
4+
5+
@Serializable
6+
internal data class GenerateImageResponse(val predictions: List<ImagenImageResponse>) {}
7+
8+
@Serializable
9+
internal data class ImagenImageResponse(
10+
val bytesBase64Encoded: String?,
11+
val gcsUri: String?,
12+
val mimeType: String,
13+
)

0 commit comments

Comments
 (0)