Skip to content

Commit c72aa6b

Browse files
authored
Image feedback (#608)
* Allow the use of generated images as inputs Change-Id: I0956fb78272a8a8af2c5219d80a26dec944040a8 # Conflicts: # google/generativeai/vision_models/_vision_models.py * types + formatting Change-Id: I0cac4ba1de764d3c02c5eab7556d8324aeda1f93 * add files Change-Id: Ie7f91cef171c1f813b52ff1b2a4daedf7ea19edd * Fix 3.9 Change-Id: If9ff9ebc0b2bf16b91e741d862a9e2808c7a738a * Fix 3.9 Change-Id: Iee02352ca21fa66da9b097d4dfa9454b67609e79 * fix pytype Change-Id: Ic5c250f3f3ded2374abfbdbee6d62ea4cfb0f799 * fix pytype Change-Id: I431c66e45e7582218b5de7a90eeeee01b80df664 * typo Change-Id: I1bb15e1363c652f9c0b4a60dad834fce65a4f0a1 * reapply commits lost in merge Change-Id: I7bfebdeaa217d93ed5d11aca31cf0b20afd38c02 * Update google/generativeai/client.py * Remove GCS reference Change-Id: I5c1b8cbccee0e13d8aca70582a76e0c089e040ed * black . Change-Id: I2c24f8798cb8103d35474e7e6d2e4fc3100825aa
1 parent aae0caf commit c72aa6b

File tree

7 files changed

+443
-437
lines changed

7 files changed

+443
-437
lines changed

google/generativeai/client.py

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import os
44
import contextlib
55
import inspect
6+
import collections
67
import dataclasses
78
import pathlib
89
from typing import Any, cast
9-
from collections.abc import Sequence
10+
from collections.abc import Sequence, Mapping
1011
import httplib2
1112
from io import IOBase
1213

@@ -23,6 +24,11 @@
2324
import googleapiclient.http
2425
import googleapiclient.discovery
2526

27+
from google.protobuf import struct_pb2
28+
29+
from proto.marshal.collections import maps
30+
from proto.marshal.collections import repeated
31+
2632
try:
2733
from google.generativeai import version
2834

@@ -130,6 +136,70 @@ async def create_file(self, *args, **kwargs):
130136
)
131137

132138

139+
# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
140+
def to_value(value) -> struct_pb2.Value:
141+
"""Return a protobuf Value object representing this value."""
142+
if isinstance(value, struct_pb2.Value):
143+
return value
144+
if value is None:
145+
return struct_pb2.Value(null_value=0)
146+
if isinstance(value, bool):
147+
return struct_pb2.Value(bool_value=value)
148+
if isinstance(value, (int, float)):
149+
return struct_pb2.Value(number_value=float(value))
150+
if isinstance(value, str):
151+
return struct_pb2.Value(string_value=value)
152+
if isinstance(value, collections.abc.Sequence):
153+
return struct_pb2.Value(list_value=to_list_value(value))
154+
if isinstance(value, collections.abc.Mapping):
155+
return struct_pb2.Value(struct_value=to_mapping_value(value))
156+
raise ValueError("Unable to coerce value: %r" % value)
157+
158+
159+
def to_list_value(value) -> struct_pb2.ListValue:
160+
# We got a proto, or else something we sent originally.
161+
# Preserve the instance we have.
162+
if isinstance(value, struct_pb2.ListValue):
163+
return value
164+
if isinstance(value, repeated.RepeatedComposite):
165+
return struct_pb2.ListValue(values=[v for v in value.pb])
166+
167+
# We got a list (or something list-like); convert it.
168+
return struct_pb2.ListValue(values=[to_value(v) for v in value])
169+
170+
171+
def to_mapping_value(value) -> struct_pb2.Struct:
172+
# We got a proto, or else something we sent originally.
173+
# Preserve the instance we have.
174+
if isinstance(value, struct_pb2.Struct):
175+
return value
176+
if isinstance(value, maps.MapComposite):
177+
return struct_pb2.Struct(
178+
fields={k: v for k, v in value.pb.items()},
179+
)
180+
181+
# We got a dict (or something dict-like); convert it.
182+
return struct_pb2.Struct(fields={k: to_value(v) for k, v in value.items()})
183+
184+
185+
class PredictionServiceClient(glm.PredictionServiceClient):
186+
def predict(self, model=None, instances=None, parameters=None):
187+
pr = protos.PredictRequest.pb()
188+
request = pr(
189+
model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters)
190+
)
191+
return super().predict(request)
192+
193+
194+
class PredictionServiceAsyncClient(glm.PredictionServiceAsyncClient):
195+
async def predict(self, model=None, instances=None, parameters=None):
196+
pr = protos.PredictRequest.pb()
197+
request = pr(
198+
model=model, instances=[to_value(i) for i in instances], parameters=to_value(parameters)
199+
)
200+
return await super().predict(request)
201+
202+
133203
@dataclasses.dataclass
134204
class _ClientManager:
135205
client_config: dict[str, Any] = dataclasses.field(default_factory=dict)
@@ -220,15 +290,20 @@ def configure(
220290
self.clients = {}
221291

222292
def make_client(self, name):
223-
if name == "file":
224-
cls = FileServiceClient
225-
elif name == "file_async":
226-
cls = FileServiceAsyncClient
227-
elif name.endswith("_async"):
228-
name = name.split("_")[0]
229-
cls = getattr(glm, name.title() + "ServiceAsyncClient")
230-
else:
231-
cls = getattr(glm, name.title() + "ServiceClient")
293+
local_clients = {
294+
"file": FileServiceClient,
295+
"file_async": FileServiceAsyncClient,
296+
"prediction": PredictionServiceClient,
297+
"prediction_async": PredictionServiceAsyncClient,
298+
}
299+
cls = local_clients.get(name, None)
300+
301+
if cls is None:
302+
if name.endswith("_async"):
303+
name = name.split("_")[0]
304+
cls = getattr(glm, name.title() + "ServiceAsyncClient")
305+
else:
306+
cls = getattr(glm, name.title() + "ServiceClient")
232307

233308
# Attempt to configure using defaults.
234309
if not self.client_config:

google/generativeai/types/content_types.py

Lines changed: 4 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -16,45 +16,16 @@
1616
from __future__ import annotations
1717

1818
from collections.abc import Iterable, Mapping, Sequence
19-
import io
2019
import inspect
21-
import mimetypes
22-
import pathlib
23-
import typing
2420
from typing import Any, Callable, Union
2521
from typing_extensions import TypedDict
2622

2723
import pydantic
2824

2925
from google.generativeai.types import file_types
26+
from google.generativeai.types.image_types import _image_types
3027
from google.generativeai import protos
3128

32-
if typing.TYPE_CHECKING:
33-
import PIL.Image
34-
import PIL.ImageFile
35-
import IPython.display
36-
37-
IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
38-
ImageType = PIL.Image.Image | IPython.display.Image
39-
else:
40-
IMAGE_TYPES = ()
41-
try:
42-
import PIL.Image
43-
import PIL.ImageFile
44-
45-
IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
46-
except ImportError:
47-
PIL = None
48-
49-
try:
50-
import IPython.display
51-
52-
IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)
53-
except ImportError:
54-
IPython = None
55-
56-
ImageType = Union["PIL.Image.Image", "IPython.display.Image"]
57-
5829

5930
__all__ = [
6031
"BlobDict",
@@ -97,62 +68,6 @@ def to_mode(x: ModeOptions) -> Mode:
9768
return _MODE[x]
9869

9970

100-
def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
101-
# If the image is a local file, return a file-based blob without any modification.
102-
# Otherwise, return a lossless WebP blob (same quality with optimized size).
103-
def file_blob(image: PIL.Image.Image) -> protos.Blob | None:
104-
if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None:
105-
return None
106-
filename = str(image.filename)
107-
if not pathlib.Path(filename).is_file():
108-
return None
109-
110-
mime_type = image.get_format_mimetype()
111-
image_bytes = pathlib.Path(filename).read_bytes()
112-
113-
return protos.Blob(mime_type=mime_type, data=image_bytes)
114-
115-
def webp_blob(image: PIL.Image.Image) -> protos.Blob:
116-
# Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
117-
image_io = io.BytesIO()
118-
image.save(image_io, format="webp", lossless=True)
119-
image_io.seek(0)
120-
121-
mime_type = "image/webp"
122-
image_bytes = image_io.read()
123-
124-
return protos.Blob(mime_type=mime_type, data=image_bytes)
125-
126-
return file_blob(image) or webp_blob(image)
127-
128-
129-
def image_to_blob(image: ImageType) -> protos.Blob:
130-
if PIL is not None:
131-
if isinstance(image, PIL.Image.Image):
132-
return _pil_to_blob(image)
133-
134-
if IPython is not None:
135-
if isinstance(image, IPython.display.Image):
136-
name = image.filename
137-
if name is None:
138-
raise ValueError(
139-
"Conversion failed. The `IPython.display.Image` can only be converted if "
140-
"it is constructed from a local file. Please ensure you are using the format: Image(filename='...')."
141-
)
142-
mime_type, _ = mimetypes.guess_type(name)
143-
if mime_type is None:
144-
mime_type = "image/unknown"
145-
146-
return protos.Blob(mime_type=mime_type, data=image.data)
147-
148-
raise TypeError(
149-
"Image conversion failed. The input was expected to be of type `Image` "
150-
"(either `PIL.Image.Image` or `IPython.display.Image`).\n"
151-
f"However, received an object of type: {type(image)}.\n"
152-
f"Object Value: {image}"
153-
)
154-
155-
15671
class BlobDict(TypedDict):
15772
mime_type: str
15873
data: bytes
@@ -189,12 +104,7 @@ def is_blob_dict(d):
189104
return "mime_type" in d and "data" in d
190105

191106

192-
if typing.TYPE_CHECKING:
193-
BlobType = Union[
194-
protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image
195-
] # Any for the images
196-
else:
197-
BlobType = Union[protos.Blob, BlobDict, Any]
107+
BlobType = Union[protos.Blob, BlobDict, _image_types.ImageType] # Any for the images
198108

199109

200110
def to_blob(blob: BlobType) -> protos.Blob:
@@ -203,8 +113,8 @@ def to_blob(blob: BlobType) -> protos.Blob:
203113

204114
if isinstance(blob, protos.Blob):
205115
return blob
206-
elif isinstance(blob, IMAGE_TYPES):
207-
return image_to_blob(blob)
116+
elif isinstance(blob, _image_types.IMAGE_TYPES):
117+
return _image_types.image_to_blob(blob)
208118
else:
209119
if isinstance(blob, Mapping):
210120
raise KeyError(
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from google.generativeai.types.image_types._image_types import *

0 commit comments

Comments
 (0)