Skip to content

Commit 8b2f214

Browse files
committed
Allow the use of generated images as inputs
Change-Id: I0956fb78272a8a8af2c5219d80a26dec944040a8
1 parent e339835 commit 8b2f214

File tree

5 files changed

+108
-406
lines changed

5 files changed

+108
-406
lines changed

google/generativeai/client.py

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import dataclasses
77
import pathlib
88
from typing import Any, cast
9-
from collections.abc import Sequence
9+
from collections.abc import Sequence, Mapping
1010
import httplib2
1111
from io import IOBase
1212

@@ -23,6 +23,11 @@
2323
import googleapiclient.http
2424
import googleapiclient.discovery
2525

26+
from google.protobuf import struct_pb2
27+
28+
from proto.marshal.collections import maps
29+
from proto.marshal.collections import repeated
30+
2631
try:
2732
from google.generativeai import version
2833

@@ -130,6 +135,73 @@ async def create_file(self, *args, **kwargs):
130135
)
131136

132137

138+
# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
139+
def to_value(value) -> struct_pb2.Value:
    """Return a protobuf ``Value`` object representing ``value``.

    The checks run in this order, and the first match wins:

    * ``struct_pb2.Value``: returned unchanged.
    * ``None``: ``null_value``.
    * ``bool``: ``bool_value`` (tested before numbers because ``bool`` is a
      subclass of ``int``).
    * ``int`` / ``float``: ``number_value``, converted with ``float()``.
    * ``str``: ``string_value`` (tested before ``Sequence`` because ``str``
      is itself a sequence).
    * any other ``Sequence``: ``list_value`` built by ``to_list_value``.
    * any ``Mapping``: ``struct_value`` built by ``to_mapping_value``.

    Raises:
        ValueError: If ``value`` matches none of the cases above.
    """
    if isinstance(value, struct_pb2.Value):
        return value
    if value is None:
        # 0 is the only member of the protobuf ``NullValue`` enum.
        return struct_pb2.Value(null_value=0)
    if isinstance(value, bool):
        return struct_pb2.Value(bool_value=value)
    if isinstance(value, (int, float)):
        return struct_pb2.Value(number_value=float(value))
    if isinstance(value, str):
        return struct_pb2.Value(string_value=value)
    # Use the ABCs imported by name at the top of the module; the bare
    # ``collections`` package is not among the visible imports.
    if isinstance(value, Sequence):
        return struct_pb2.Value(list_value=to_list_value(value))
    if isinstance(value, Mapping):
        return struct_pb2.Value(struct_value=to_mapping_value(value))
    raise ValueError(f"Unable to coerce value: {value!r}")
156+
157+
158+
def to_list_value(value) -> struct_pb2.ListValue:
    """Return a protobuf ``ListValue`` built from ``value``.

    * A ``struct_pb2.ListValue`` is returned as the same instance.
    * A proto-plus ``RepeatedComposite`` contributes the elements of its
      ``.pb`` attribute as-is, without per-element conversion.
    * Anything else is iterated and each element is passed through
      ``to_value``.
    """
    # Already the target proto type: keep the instance we were given.
    if isinstance(value, struct_pb2.ListValue):
        return value

    if isinstance(value, repeated.RepeatedComposite):
        items = list(value.pb)
    else:
        # A list (or something list-like): convert element by element.
        items = [to_value(item) for item in value]

    return struct_pb2.ListValue(values=items)
168+
169+
170+
def to_mapping_value(value) -> struct_pb2.Struct:
    """Return a protobuf ``Struct`` built from ``value``.

    * A ``struct_pb2.Struct`` is returned as the same instance.
    * A proto-plus ``MapComposite`` contributes the key/value pairs of its
      ``.pb`` attribute as-is, without per-value conversion.
    * Anything else must provide ``.items()``; each value is passed through
      ``to_value``.
    """
    # Already the target proto type: keep the instance we were given.
    if isinstance(value, struct_pb2.Struct):
        return value

    if isinstance(value, maps.MapComposite):
        entries = dict(value.pb.items())
    else:
        # A dict (or something dict-like): convert value by value.
        entries = {key: to_value(item) for key, item in value.items()}

    return struct_pb2.Struct(fields=entries)
182+
183+
184+
# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
185+
186+
187+
class PredictionServiceClient(glm.PredictionServiceClient):
    """Prediction client that builds its request from plain Python values.

    NOTE(review): this appears to exist to work around
    https://github.com/googleapis/proto-plus-python/issues/488 by converting
    values with ``to_value`` instead of relying on proto-plus marshalling.
    """

    def predict(self, model=None, instances=None, parameters=None):
        """Build a ``PredictRequest`` and forward it to the parent client.

        Args:
            model: Passed through unchanged as the request's ``model``.
            instances: Iterable of values, each converted with ``to_value``.
                ``None`` (the default) is treated as no instances.
            parameters: Single value converted with ``to_value``; ``None``
                becomes a null ``Value``.

        Returns:
            Whatever the parent class's ``predict`` returns.
        """
        # ``pb()`` without an argument yields the class used to construct the
        # request directly from the already-converted values.
        pr = protos.PredictRequest.pb()
        request = pr(
            model=model,
            # Guard the default: iterating ``None`` would raise TypeError.
            instances=[to_value(i) for i in (instances or ())],
            parameters=to_value(parameters),
        )
        return super().predict(request)
194+
195+
196+
class PredictionServiceAsyncClient(glm.PredictionServiceAsyncClient):
    """Async prediction client that builds its request from plain Python values.

    NOTE(review): this appears to exist to work around
    https://github.com/googleapis/proto-plus-python/issues/488 by converting
    values with ``to_value`` instead of relying on proto-plus marshalling.
    """

    async def predict(self, model=None, instances=None, parameters=None):
        """Build a ``PredictRequest`` and await the parent client's ``predict``.

        Args:
            model: Passed through unchanged as the request's ``model``.
            instances: Iterable of values, each converted with ``to_value``.
                ``None`` (the default) is treated as no instances.
            parameters: Single value converted with ``to_value``; ``None``
                becomes a null ``Value``.

        Returns:
            Whatever the awaited parent ``predict`` returns.
        """
        # ``pb()`` without an argument yields the class used to construct the
        # request directly from the already-converted values.
        pr = protos.PredictRequest.pb()
        request = pr(
            model=model,
            # Guard the default: iterating ``None`` would raise TypeError.
            instances=[to_value(i) for i in (instances or ())],
            parameters=to_value(parameters),
        )
        return await super().predict(request)
203+
204+
133205
@dataclasses.dataclass
134206
class _ClientManager:
135207
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
@@ -220,15 +292,20 @@ def configure(
220292
self.clients = {}
221293

222294
def make_client(self, name):
223-
if name == "file":
224-
cls = FileServiceClient
225-
elif name == "file_async":
226-
cls = FileServiceAsyncClient
227-
elif name.endswith("_async"):
228-
name = name.split("_")[0]
229-
cls = getattr(glm, name.title() + "ServiceAsyncClient")
230-
else:
231-
cls = getattr(glm, name.title() + "ServiceClient")
295+
# Client classes defined in this module take precedence over the generated
# ``glm`` clients of the same name.
local_clients = {
    "file": FileServiceClient,
    "file_async": FileServiceAsyncClient,
    "prediction": PredictionServiceClient,
    "prediction_async": PredictionServiceAsyncClient,
}
# Look up the *value* of ``name``. The string literal "name" is not a key in
# the table, so using it would always return None and the local subclasses
# above would never be selected.
cls = local_clients.get(name)

if cls is None:
    # Fall back to the generated client, deriving the class name from ``name``.
    if name.endswith("_async"):
        name = name.split("_")[0]
        cls = getattr(glm, name.title() + "ServiceAsyncClient")
    else:
        cls = getattr(glm, name.title() + "ServiceClient")
232309

233310
# Attempt to configure using defaults.
234311
if not self.client_config:

google/generativeai/types/content_types.py

Lines changed: 4 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -16,45 +16,16 @@
1616
from __future__ import annotations
1717

1818
from collections.abc import Iterable, Mapping, Sequence
19-
import io
2019
import inspect
21-
import mimetypes
22-
import pathlib
23-
import typing
2420
from typing import Any, Callable, Union
2521
from typing_extensions import TypedDict
2622

2723
import pydantic
2824

2925
from google.generativeai.types import file_types
26+
from google.generativeai.types.image_types import _image_types
3027
from google.generativeai import protos
3128

32-
if typing.TYPE_CHECKING:
33-
import PIL.Image
34-
import PIL.ImageFile
35-
import IPython.display
36-
37-
IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
38-
ImageType = PIL.Image.Image | IPython.display.Image
39-
else:
40-
IMAGE_TYPES = ()
41-
try:
42-
import PIL.Image
43-
import PIL.ImageFile
44-
45-
IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
46-
except ImportError:
47-
PIL = None
48-
49-
try:
50-
import IPython.display
51-
52-
IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)
53-
except ImportError:
54-
IPython = None
55-
56-
ImageType = Union["PIL.Image.Image", "IPython.display.Image"]
57-
5829

5930
__all__ = [
6031
"BlobDict",
@@ -97,62 +68,6 @@ def to_mode(x: ModeOptions) -> Mode:
9768
return _MODE[x]
9869

9970

100-
def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
    """Convert a PIL image into a ``protos.Blob``.

    If the image is backed by an existing local file, the file's bytes are
    used unmodified together with the image's format MIME type. Otherwise the
    image is re-encoded as lossless WebP.
    """

    def _from_source_file(img: PIL.Image.Image) -> protos.Blob | None:
        # Only file-backed images that carry a filename qualify.
        file_backed = isinstance(img, PIL.ImageFile.ImageFile) and img.filename is not None
        if not file_backed:
            return None

        source = pathlib.Path(str(img.filename))
        if not source.is_file():
            return None

        return protos.Blob(mime_type=img.get_format_mimetype(), data=source.read_bytes())

    def _as_lossless_webp(img: PIL.Image.Image) -> protos.Blob:
        # Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
        buffer = io.BytesIO()
        img.save(buffer, format="webp", lossless=True)
        return protos.Blob(mime_type="image/webp", data=buffer.getvalue())

    blob = _from_source_file(image)
    return blob or _as_lossless_webp(image)
127-
128-
129-
def image_to_blob(image: ImageType) -> protos.Blob:
    """Convert a supported image object into a ``protos.Blob``.

    PIL images are delegated to ``_pil_to_blob``. ``IPython.display.Image``
    objects must have a ``filename``; the MIME type is guessed from that name
    (falling back to ``"image/unknown"``) and the blob carries ``image.data``.

    Raises:
        ValueError: If an ``IPython.display.Image`` has no ``filename``.
        TypeError: If ``image`` is neither of the supported image types, or
            the matching library is unavailable (its module name is ``None``).
    """
    if PIL is not None and isinstance(image, PIL.Image.Image):
        return _pil_to_blob(image)

    if IPython is not None and isinstance(image, IPython.display.Image):
        source_name = image.filename
        if source_name is None:
            raise ValueError(
                "Conversion failed. The `IPython.display.Image` can only be converted if "
                "it is constructed from a local file. Please ensure you are using the format: Image(filename='...')."
            )

        guessed, _encoding = mimetypes.guess_type(source_name)
        mime_type = "image/unknown" if guessed is None else guessed
        return protos.Blob(mime_type=mime_type, data=image.data)

    raise TypeError(
        "Image conversion failed. The input was expected to be of type `Image` "
        "(either `PIL.Image.Image` or `IPython.display.Image`).\n"
        f"However, received an object of type: {type(image)}.\n"
        f"Object Value: {image}"
    )
154-
155-
15671
class BlobDict(TypedDict):
15772
mime_type: str
15873
data: bytes
@@ -189,12 +104,7 @@ def is_blob_dict(d):
189104
return "mime_type" in d and "data" in d
190105

191106

192-
if typing.TYPE_CHECKING:
193-
BlobType = Union[
194-
protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image
195-
] # Any for the images
196-
else:
197-
BlobType = Union[protos.Blob, BlobDict, Any]
107+
BlobType = Union[protos.Blob, BlobDict, _image_types.ImageType]  # image inputs are typed by _image_types.ImageType
198108

199109

200110
def to_blob(blob: BlobType) -> protos.Blob:
@@ -203,8 +113,8 @@ def to_blob(blob: BlobType) -> protos.Blob:
203113

204114
if isinstance(blob, protos.Blob):
205115
return blob
206-
elif isinstance(blob, IMAGE_TYPES):
207-
return image_to_blob(blob)
116+
elif isinstance(blob, _image_types.IMAGE_TYPES):
117+
return _image_types.image_to_blob(blob)
208118
else:
209119
if isinstance(blob, Mapping):
210120
raise KeyError(

google/generativeai/vision_models/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
#
1515
"""Classes for working with vision models."""
1616

17+
from google.generativeai.types.image_types import check_watermark, Image, GeneratedImage
18+
1719
from google.generativeai.vision_models._vision_models import (
18-
check_watermark,
19-
Image,
20-
GeneratedImage,
2120
ImageGenerationModel,
2221
ImageGenerationResponse,
2322
)
2423

2524
__all__ = [
25+
"check_watermark",
2626
"Image",
2727
"GeneratedImage",
2828
"ImageGenerationModel",

0 commit comments

Comments
 (0)