25
25
import typing
26
26
from typing import Any , Dict , List , Literal , Optional , Union
27
27
28
+ from google .generativeai import client
29
+ from google .generativeai import protos
30
+
28
31
from google .protobuf import struct_pb2
29
32
30
33
from proto .marshal .collections import maps
43
46
PIL_Image = None
44
47
45
48
49
+ # This is to get around https://github.com/googleapis/proto-plus-python/issues/488
46
50
def to_value (value ) -> struct_pb2 .Value :
47
51
"""Return a protobuf Value object representing this value."""
48
52
if isinstance (value , struct_pb2 .Value ):
@@ -61,6 +65,7 @@ def to_value(value) -> struct_pb2.Value:
61
65
return struct_pb2 .Value (struct_value = to_mapping_value (value ))
62
66
raise ValueError ("Unable to coerce value: %r" % value )
63
67
68
+
64
69
def to_list_value (value ) -> struct_pb2 .ListValue :
65
70
# We got a proto, or else something we sent originally.
66
71
# Preserve the instance we have.
@@ -70,9 +75,8 @@ def to_list_value(value) -> struct_pb2.ListValue:
70
75
return struct_pb2 .ListValue (values = [v for v in value .pb ])
71
76
72
77
# We got a list (or something list-like); convert it.
73
- return struct_pb2 .ListValue (
74
- values = [to_value (v ) for v in value ]
75
- )
78
+ return struct_pb2 .ListValue (values = [to_value (v ) for v in value ])
79
+
76
80
77
81
def to_mapping_value (value ) -> struct_pb2 .Struct :
78
82
# We got a proto, or else something we sent originally.
@@ -85,12 +89,7 @@ def to_mapping_value(value) -> struct_pb2.Struct:
85
89
)
86
90
87
91
# We got a dict (or something dict-like); convert it.
88
- return struct_pb2 .Struct (
89
- fields = {
90
- k : to_value (v ) for k , v in value .items ()
91
- }
92
- )
93
-
92
+ return struct_pb2 .Struct (fields = {k : to_value (v ) for k , v in value .items ()})
94
93
95
94
96
95
_SUPPORTED_UPSCALING_SIZES = [2048 , 4096 ]
@@ -131,7 +130,6 @@ def load_from_file(location: str) -> "Image":
131
130
image = Image (image_bytes = image_bytes )
132
131
return image
133
132
134
-
135
133
@property
136
134
def _image_bytes (self ) -> bytes :
137
135
return self ._loaded_bytes
@@ -206,9 +204,16 @@ class ImageGenerationModel:
206
204
response[0].save("image1.png")
207
205
"""
208
206
209
- __module__ = "vertexai.preview.vision_models"
207
def __init__(self, model_id: str):
    """Create a wrapper around the named Imagen model.

    Args:
        model_id: Model identifier, with or without the ``models/`` prefix
            (e.g. ``"imagen-3.0"`` or ``"models/imagen-3.0"``).
    """
    # Normalize to the fully-qualified "models/..." resource name.
    qualified = model_id if model_id.startswith("models") else f"models/{model_id}"
    self.model_name = qualified
    # The prediction client is created lazily on the first request.
    self._client = None
210
212
211
- _INSTANCE_SCHEMA_URI = "gs://google-cloud-aiplatform/schema/predict/instance/vision_generative_model_1.0.0.yaml"
213
@classmethod
def from_pretrained(cls, model_name: str):
    """Alternate constructor mirroring the Vertex AI SDK interface.

    Simply forwards *model_name* to the regular constructor so code written
    against ``vertexai.preview.vision_models`` keeps working unchanged.
    """
    return cls(model_name)
212
217
213
218
def _generate_images (
214
219
self ,
@@ -242,9 +247,7 @@ def _generate_images(
242
247
safety_filter_level : Optional [
243
248
Literal ["block_most" , "block_some" , "block_few" , "block_fewest" ]
244
249
] = None ,
245
- person_generation : Optional [
246
- Literal ["dont_allow" , "allow_adult" , "allow_all" ]
247
- ] = None ,
250
+ person_generation : Optional [Literal ["dont_allow" , "allow_adult" , "allow_all" ]] = None ,
248
251
) -> "ImageGenerationResponse" :
249
252
"""Generates images from text prompt.
250
253
@@ -312,6 +315,8 @@ class ID
312
315
Returns:
313
316
An `ImageGenerationResponse` object.
314
317
"""
318
+ if self ._client is None :
319
+ self ._client = client .get_default_prediction_client ()
315
320
# Note: Only a single prompt is supported by the service.
316
321
instance = {"prompt" : prompt }
317
322
shared_generation_parameters = {
@@ -412,11 +417,14 @@ class ID
412
417
parameters ["personGeneration" ] = person_generation
413
418
shared_generation_parameters ["person_generation" ] = person_generation
414
419
415
- response = self ._endpoint .predict (
416
- instances = [to_value (instance )],
417
- parameters = parameters ,
420
+ # This is to get around https://github.com/googleapis/proto-plus-python/issues/488
421
+ pr = protos .PredictRequest .pb ()
422
+ request = pr (
423
+ model = self .model_name , instances = [to_value (instance )], parameters = to_value (parameters )
418
424
)
419
425
426
+ response = self ._client .predict (request )
427
+
420
428
generated_images : List ["GeneratedImage" ] = []
421
429
for idx , prediction in enumerate (response .predictions ):
422
430
generation_parameters = dict (shared_generation_parameters )
@@ -444,9 +452,7 @@ def generate_images(
444
452
safety_filter_level : Optional [
445
453
Literal ["block_most" , "block_some" , "block_few" , "block_fewest" ]
446
454
] = None ,
447
- person_generation : Optional [
448
- Literal ["dont_allow" , "allow_adult" , "allow_all" ]
449
- ] = None ,
455
+ person_generation : Optional [Literal ["dont_allow" , "allow_adult" , "allow_all" ]] = None ,
450
456
) -> "ImageGenerationResponse" :
451
457
"""Generates images from text prompt.
452
458
@@ -510,9 +516,7 @@ def edit_image(
510
516
number_of_images : int = 1 ,
511
517
guidance_scale : Optional [float ] = None ,
512
518
edit_mode : Optional [
513
- Literal [
514
- "inpainting-insert" , "inpainting-remove" , "outpainting" , "product-image"
515
- ]
519
+ Literal ["inpainting-insert" , "inpainting-remove" , "outpainting" , "product-image" ]
516
520
] = None ,
517
521
mask_mode : Optional [Literal ["background" , "foreground" , "semantic" ]] = None ,
518
522
segmentation_classes : Optional [List [str ]] = None ,
@@ -525,9 +529,7 @@ def edit_image(
525
529
safety_filter_level : Optional [
526
530
Literal ["block_most" , "block_some" , "block_few" , "block_fewest" ]
527
531
] = None ,
528
- person_generation : Optional [
529
- Literal ["dont_allow" , "allow_adult" , "allow_all" ]
530
- ] = None ,
532
+ person_generation : Optional [Literal ["dont_allow" , "allow_adult" , "allow_all" ]] = None ,
531
533
) -> "ImageGenerationResponse" :
532
534
"""Edits an existing image based on text prompt.
533
535
@@ -717,9 +719,7 @@ def upscale_image(
717
719
718
720
parameters ["outputOptions" ] = {"mimeType" : output_mime_type }
719
721
if output_mime_type == "image/jpeg" and output_compression_quality is not None :
720
- parameters ["outputOptions" ][
721
- "compressionQuality"
722
- ] = output_compression_quality
722
+ parameters ["outputOptions" ]["compressionQuality" ] = output_compression_quality
723
723
724
724
response = self ._endpoint .predict (
725
725
instances = [to_value (instance )],
@@ -825,9 +825,7 @@ def save(self, location: str, include_generation_parameters: bool = True):
825
825
if not self ._generation_parameters :
826
826
raise ValueError ("Image does not have generation parameters." )
827
827
if not PIL_Image :
828
- raise ValueError (
829
- "The PIL module is required for saving generation parameters."
830
- )
828
+ raise ValueError ("The PIL module is required for saving generation parameters." )
831
829
832
830
exif = self ._pil_image .getexif ()
833
831
exif [_EXIF_USER_COMMENT_TAG_IDX ] = json .dumps (
@@ -836,4 +834,3 @@ def save(self, location: str, include_generation_parameters: bool = True):
836
834
self ._pil_image .save (location , exif = exif )
837
835
else :
838
836
super ().save (location = location )
839
-
0 commit comments