Skip to content

Commit e21e918

Browse files
committed
Basically works.
Change-Id: I28364ab70b2a263b29026f2cf2d1d4f807d88f53
1 parent 59642c2 commit e21e918

File tree

3 files changed

+41
-34
lines changed

3 files changed

+41
-34
lines changed

google/generativeai/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
from google.generativeai.generative_models import GenerativeModel
6060
from google.generativeai.generative_models import ChatSession
6161

62+
from google.generativeai.vision_models import *
63+
6264
from google.generativeai.models import list_models
6365
from google.generativeai.models import list_tuned_models
6466

google/generativeai/client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,11 @@ def get_default_permission_client() -> glm.PermissionServiceClient:
377377

378378
def get_default_permission_async_client() -> glm.PermissionServiceAsyncClient:
    """Return the shared asynchronous client for the Permission service."""
    return _client_manager.get_default_client("permission_async")
380+
381+
382+
def get_default_prediction_client() -> glm.PredictionServiceClient:
    """Return the shared synchronous client for the Prediction service.

    Note: the return annotation previously named ``glm.PermissionServiceClient``
    (a copy-paste from the permission helpers above); the client fetched here
    is the "prediction" client.
    """
    return _client_manager.get_default_client("prediction")
384+
385+
386+
def get_default_prediction_async_client() -> glm.PredictionServiceAsyncClient:
    """Return the shared asynchronous client for the Prediction service.

    Note: the return annotation previously named
    ``glm.PermissionServiceAsyncClient`` (a copy-paste from the permission
    helpers above); the client fetched here is the "prediction_async" client.
    """
    return _client_manager.get_default_client("prediction_async")

google/generativeai/vision_models/_vision_models.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
import typing
2626
from typing import Any, Dict, List, Literal, Optional, Union
2727

28+
from google.generativeai import client
29+
from google.generativeai import protos
30+
2831
from google.protobuf import struct_pb2
2932

3033
from proto.marshal.collections import maps
@@ -43,6 +46,7 @@
4346
PIL_Image = None
4447

4548

49+
# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
4650
def to_value(value) -> struct_pb2.Value:
4751
"""Return a protobuf Value object representing this value."""
4852
if isinstance(value, struct_pb2.Value):
@@ -61,6 +65,7 @@ def to_value(value) -> struct_pb2.Value:
6165
return struct_pb2.Value(struct_value=to_mapping_value(value))
6266
raise ValueError("Unable to coerce value: %r" % value)
6367

68+
6469
def to_list_value(value) -> struct_pb2.ListValue:
6570
# We got a proto, or else something we sent originally.
6671
# Preserve the instance we have.
@@ -70,9 +75,8 @@ def to_list_value(value) -> struct_pb2.ListValue:
7075
return struct_pb2.ListValue(values=[v for v in value.pb])
7176

7277
# We got a list (or something list-like); convert it.
73-
return struct_pb2.ListValue(
74-
values=[to_value(v) for v in value]
75-
)
78+
return struct_pb2.ListValue(values=[to_value(v) for v in value])
79+
7680

7781
def to_mapping_value(value) -> struct_pb2.Struct:
7882
# We got a proto, or else something we sent originally.
@@ -85,12 +89,7 @@ def to_mapping_value(value) -> struct_pb2.Struct:
8589
)
8690

8791
# We got a dict (or something dict-like); convert it.
88-
return struct_pb2.Struct(
89-
fields={
90-
k: to_value(v) for k, v in value.items()
91-
}
92-
)
93-
92+
return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()})
9493

9594

9695
_SUPPORTED_UPSCALING_SIZES = [2048, 4096]
@@ -131,7 +130,6 @@ def load_from_file(location: str) -> "Image":
131130
image = Image(image_bytes=image_bytes)
132131
return image
133132

134-
135133
@property
def _image_bytes(self) -> bytes:
    """The raw, already-loaded bytes backing this image."""
    return self._loaded_bytes
@@ -206,9 +204,16 @@ class ImageGenerationModel:
206204
response[0].save("image1.png")
207205
"""
208206

209-
__module__ = "vertexai.preview.vision_models"
207+
def __init__(self, model_id: str):
208+
if not model_id.startswith("models"):
209+
model_id = f"models/{model_id}"
210+
self.model_name = model_id
211+
self._client = None
210212

211-
_INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_generative_model_1.0.0.yaml"
213+
@classmethod
def from_pretrained(cls, model_name: str):
    """Alternate constructor kept for Vertex AI API compatibility.

    Simply delegates to the regular constructor.
    """
    return cls(model_name)
212217

213218
def _generate_images(
214219
self,
@@ -242,9 +247,7 @@ def _generate_images(
242247
safety_filter_level: Optional[
243248
Literal["block_most", "block_some", "block_few", "block_fewest"]
244249
] = None,
245-
person_generation: Optional[
246-
Literal["dont_allow", "allow_adult", "allow_all"]
247-
] = None,
250+
person_generation: Optional[Literal["dont_allow", "allow_adult", "allow_all"]] = None,
248251
) -> "ImageGenerationResponse":
249252
"""Generates images from text prompt.
250253
@@ -312,6 +315,8 @@ class ID
312315
Returns:
313316
An `ImageGenerationResponse` object.
314317
"""
318+
if self._client is None:
319+
self._client = client.get_default_prediction_client()
315320
# Note: Only a single prompt is supported by the service.
316321
instance = {"prompt": prompt}
317322
shared_generation_parameters = {
@@ -412,11 +417,14 @@ class ID
412417
parameters["personGeneration"] = person_generation
413418
shared_generation_parameters["person_generation"] = person_generation
414419

415-
response = self._endpoint.predict(
416-
instances=[to_value(instance)],
417-
parameters=parameters,
420+
# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
421+
pr = protos.PredictRequest.pb()
422+
request = pr(
423+
model=self.model_name, instances=[to_value(instance)], parameters=to_value(parameters)
418424
)
419425

426+
response = self._client.predict(request)
427+
420428
generated_images: List["GeneratedImage"] = []
421429
for idx, prediction in enumerate(response.predictions):
422430
generation_parameters = dict(shared_generation_parameters)
@@ -444,9 +452,7 @@ def generate_images(
444452
safety_filter_level: Optional[
445453
Literal["block_most", "block_some", "block_few", "block_fewest"]
446454
] = None,
447-
person_generation: Optional[
448-
Literal["dont_allow", "allow_adult", "allow_all"]
449-
] = None,
455+
person_generation: Optional[Literal["dont_allow", "allow_adult", "allow_all"]] = None,
450456
) -> "ImageGenerationResponse":
451457
"""Generates images from text prompt.
452458
@@ -510,9 +516,7 @@ def edit_image(
510516
number_of_images: int = 1,
511517
guidance_scale: Optional[float] = None,
512518
edit_mode: Optional[
513-
Literal[
514-
"inpainting-insert", "inpainting-remove", "outpainting", "product-image"
515-
]
519+
Literal["inpainting-insert", "inpainting-remove", "outpainting", "product-image"]
516520
] = None,
517521
mask_mode: Optional[Literal["background", "foreground", "semantic"]] = None,
518522
segmentation_classes: Optional[List[str]] = None,
@@ -525,9 +529,7 @@ def edit_image(
525529
safety_filter_level: Optional[
526530
Literal["block_most", "block_some", "block_few", "block_fewest"]
527531
] = None,
528-
person_generation: Optional[
529-
Literal["dont_allow", "allow_adult", "allow_all"]
530-
] = None,
532+
person_generation: Optional[Literal["dont_allow", "allow_adult", "allow_all"]] = None,
531533
) -> "ImageGenerationResponse":
532534
"""Edits an existing image based on text prompt.
533535
@@ -717,9 +719,7 @@ def upscale_image(
717719

718720
parameters["outputOptions"] = {"mimeType": output_mime_type}
719721
if output_mime_type == "image/jpeg" and output_compression_quality is not None:
720-
parameters["outputOptions"][
721-
"compressionQuality"
722-
] = output_compression_quality
722+
parameters["outputOptions"]["compressionQuality"] = output_compression_quality
723723

724724
response = self._endpoint.predict(
725725
instances=[to_value(instance)],
@@ -825,9 +825,7 @@ def save(self, location: str, include_generation_parameters: bool = True):
825825
if not self._generation_parameters:
826826
raise ValueError("Image does not have generation parameters.")
827827
if not PIL_Image:
828-
raise ValueError(
829-
"The PIL module is required for saving generation parameters."
830-
)
828+
raise ValueError("The PIL module is required for saving generation parameters.")
831829

832830
exif = self._pil_image.getexif()
833831
exif[_EXIF_USER_COMMENT_TAG_IDX] = json.dumps(
@@ -836,4 +834,3 @@ def save(self, location: str, include_generation_parameters: bool = True):
836834
self._pil_image.save(location, exif=exif)
837835
else:
838836
super().save(location=location)
839-

0 commit comments

Comments
 (0)