Skip to content

Commit 21268e3

Browse files
committed
add files
Change-Id: Ie7f91cef171c1f813b52ff1b2a4daedf7ea19edd
1 parent 170df6a commit 21268e3

File tree

2 files changed

+333
-0
lines changed

2 files changed

+333
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from google.generativeai.types.image_types._image_types import *
Lines changed: 332 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,332 @@
1+
import base64
2+
import io
3+
import json
4+
import mimetypes
5+
import os
6+
import pathlib
7+
import typing
8+
from typing import Any, Dict, Optional, Union
9+
10+
from google.generativeai import protos
11+
from google.generativeai import client
12+
13+
# pylint: disable=g-import-not-at-top
14+
if typing.TYPE_CHECKING:
15+
import PIL.Image
16+
import PIL.ImageFile
17+
import IPython.display
18+
19+
IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
20+
ImageType = PIL.Image.Image | IPython.display.Image
21+
else:
22+
IMAGE_TYPES = ()
23+
try:
24+
import PIL.Image
25+
import PIL.ImageFile
26+
27+
IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
28+
except ImportError:
29+
PIL = None
30+
31+
try:
32+
import IPython.display
33+
34+
IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)
35+
except ImportError:
36+
IPython = None
37+
38+
ImageType = Union["Image", "PIL.Image.Image", "IPython.display.Image"]
39+
# pylint: enable=g-import-not-at-top
40+
41+
__all__ = ["Image", "GeneratedImage", "check_watermark", "CheckWatermarkResult", "ImageType"]
42+
43+
44+
def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
45+
# If the image is a local file, return a file-based blob without any modification.
46+
# Otherwise, return a lossless WebP blob (same quality with optimized size).
47+
def file_blob(image: PIL.Image.Image) -> protos.Blob | None:
48+
if not isinstance(image, PIL.ImageFile.ImageFile) or image.filename is None:
49+
return None
50+
filename = str(image.filename)
51+
if not pathlib.Path(filename).is_file():
52+
return None
53+
54+
mime_type = image.get_format_mimetype()
55+
image_bytes = pathlib.Path(filename).read_bytes()
56+
57+
return protos.Blob(mime_type=mime_type, data=image_bytes)
58+
59+
def webp_blob(image: PIL.Image.Image) -> protos.Blob:
60+
# Reference: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
61+
image_io = io.BytesIO()
62+
image.save(image_io, format="webp", lossless=True)
63+
image_io.seek(0)
64+
65+
mime_type = "image/webp"
66+
image_bytes = image_io.read()
67+
68+
return protos.Blob(mime_type=mime_type, data=image_bytes)
69+
70+
return file_blob(image) or webp_blob(image)
71+
72+
73+
def image_to_blob(image: ImageType) -> protos.Blob:
74+
if PIL is not None:
75+
if isinstance(image, PIL.Image.Image):
76+
return _pil_to_blob(image)
77+
78+
if IPython is not None:
79+
if isinstance(image, IPython.display.Image):
80+
name = image.filename
81+
if name is None:
82+
raise ValueError(
83+
"Conversion failed. The `IPython.display.Image` can only be converted if "
84+
"it is constructed from a local file. Please ensure you are using the format: Image(filename='...')."
85+
)
86+
mime_type, _ = mimetypes.guess_type(name)
87+
if mime_type is None:
88+
mime_type = "image/unknown"
89+
90+
return protos.Blob(mime_type=mime_type, data=image.data)
91+
92+
if isinstance(image, Image):
93+
return protos.Blob(mime_type=image._mime_type, data=image._image_bytes)
94+
95+
raise TypeError(
96+
"Image conversion failed. The input was expected to be of type `Image` "
97+
"(either `PIL.Image.Image` or `IPython.display.Image`).\n"
98+
f"However, received an object of type: {type(image)}.\n"
99+
f"Object Value: {image}"
100+
)
101+
102+
103+
class CheckWatermarkResult:
104+
def __init__(self, predictions):
105+
self._predictions = predictions
106+
107+
@property
108+
def decision(self):
109+
return self._predictions[0]["decision"]
110+
111+
def __str__(self):
112+
return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])"
113+
114+
def __bool__(self):
115+
decision = self.decision
116+
if decision == "ACCEPT":
117+
return True
118+
elif decision == "REJECT":
119+
return False
120+
else:
121+
raise ValueError("Unrecognized result")
122+
123+
124+
def check_watermark(
125+
img: pathlib.Path | ImageType, model_id: str = "models/image-verification-001"
126+
) -> "CheckWatermarkResult":
127+
"""Checks if an image has a Google-AI watermark.
128+
129+
Args:
130+
img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPythin.display.Image`, or `google.generativeai.Image`.
131+
model_id: Which version of the image-verification model to send the image to.
132+
133+
Returns:
134+
135+
"""
136+
if isinstance(img, Image):
137+
pass
138+
elif isinstance(img, pathlib.Path):
139+
img = Image.load_from_file(img)
140+
elif IPython.display is not None and isinstance(img, IPython.display.Image):
141+
img = Image(image_bytes=img.data)
142+
elif PIL.Image is not None and isinstance(img, PIL.Image.Image):
143+
blob = _pil_to_blob(img)
144+
img = Image(image_bytes=blob.data)
145+
elif isinstance(img, protos.Blob):
146+
img = Image(image_bytes=img.data)
147+
else:
148+
raise TypeError(
149+
f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
150+
)
151+
152+
prediction_client = client.get_default_prediction_client()
153+
if not model_id.startswith("models/"):
154+
model_id = f"models/{model_id}"
155+
156+
instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}}
157+
parameters = {"watermarkVerification": True}
158+
159+
response = prediction_client.predict(
160+
model=model_id, instances=[instance], parameters=parameters
161+
)
162+
163+
return CheckWatermarkResult(response.predictions)
164+
165+
166+
class Image:
167+
"""Image."""
168+
169+
__module__ = "vertexai.vision_models"
170+
171+
_loaded_bytes: Optional[bytes] = None
172+
_loaded_image: Optional["PIL_Image.Image"] = None
173+
174+
def __init__(
175+
self,
176+
image_bytes: Optional[bytes],
177+
):
178+
"""Creates an `Image` object.
179+
180+
Args:
181+
image_bytes: Image file bytes. Image can be in PNG or JPEG format.
182+
"""
183+
self._image_bytes = image_bytes
184+
185+
@staticmethod
186+
def load_from_file(location: os.PathLike) -> "Image":
187+
"""Loads image from local file or Google Cloud Storage.
188+
189+
Args:
190+
location: Local path or Google Cloud Storage uri from where to load
191+
the image.
192+
193+
Returns:
194+
Loaded image as an `Image` object.
195+
"""
196+
# Load image from local path
197+
image_bytes = pathlib.Path(location).read_bytes()
198+
image = Image(image_bytes=image_bytes)
199+
return image
200+
201+
@property
202+
def _image_bytes(self) -> bytes:
203+
return self._loaded_bytes
204+
205+
@_image_bytes.setter
206+
def _image_bytes(self, value: bytes):
207+
self._loaded_bytes = value
208+
209+
@property
210+
def _pil_image(self) -> "PIL_Image.Image": # type: ignore
211+
if self._loaded_image is None:
212+
if not PIL:
213+
raise RuntimeError(
214+
"The PIL module is not available. Please install the Pillow package."
215+
)
216+
self._loaded_image = PIL.Image.open(io.BytesIO(self._image_bytes))
217+
return self._loaded_image
218+
219+
@property
220+
def _size(self):
221+
return self._pil_image.size
222+
223+
@property
224+
def _mime_type(self) -> str:
225+
"""Returns the MIME type of the image."""
226+
import PIL
227+
228+
return PIL.Image.MIME.get(self._pil_image.format, "image/jpeg")
229+
230+
def show(self):
231+
"""Shows the image.
232+
233+
This method only works when in a notebook environment.
234+
"""
235+
if PIL and IPython:
236+
IPython.display.display(self._pil_image)
237+
238+
def save(self, location: str):
239+
"""Saves image to a file.
240+
241+
Args:
242+
location: Local path where to save the image.
243+
"""
244+
pathlib.Path(location).write_bytes(self._image_bytes)
245+
246+
def _as_base64_string(self) -> str:
247+
"""Encodes image using the base64 encoding.
248+
249+
Returns:
250+
Base64 encoding of the image as a string.
251+
"""
252+
# ! b64encode returns `bytes` object, not `str`.
253+
# We need to convert `bytes` to `str`, otherwise we get service error:
254+
# "received initial metadata size exceeds limit"
255+
return base64.b64encode(self._image_bytes).decode("ascii")
256+
257+
def _repr_png_(self):
258+
return self._pil_image._repr_png_() # type:ignore
259+
260+
check_watermark = check_watermark
261+
262+
263+
_EXIF_USER_COMMENT_TAG_IDX = 0x9286
264+
_IMAGE_GENERATION_PARAMETERS_EXIF_KEY = (
265+
"google.cloud.vertexai.image_generation.image_generation_parameters"
266+
)
267+
268+
269+
class GeneratedImage(Image):
270+
"""Generated image."""
271+
272+
__module__ = "google.generativeai"
273+
274+
def __init__(
275+
self,
276+
image_bytes: Optional[bytes],
277+
generation_parameters: Dict[str, Any],
278+
):
279+
"""Creates a `GeneratedImage` object.
280+
281+
Args:
282+
image_bytes: Image file bytes. Image can be in PNG or JPEG format.
283+
generation_parameters: Image generation parameter values.
284+
"""
285+
super().__init__(image_bytes=image_bytes)
286+
self._generation_parameters = generation_parameters
287+
288+
@property
289+
def generation_parameters(self):
290+
"""Image generation parameters as a dictionary."""
291+
return self._generation_parameters
292+
293+
@staticmethod
294+
def load_from_file(location: os.PathLike) -> "GeneratedImage":
295+
"""Loads image from file.
296+
297+
Args:
298+
location: Local path from where to load the image.
299+
300+
Returns:
301+
Loaded image as a `GeneratedImage` object.
302+
"""
303+
base_image = Image.load_from_file(location=location)
304+
exif = base_image._pil_image.getexif() # pylint: disable=protected-access
305+
exif_comment_dict = json.loads(exif[_EXIF_USER_COMMENT_TAG_IDX])
306+
generation_parameters = exif_comment_dict[_IMAGE_GENERATION_PARAMETERS_EXIF_KEY]
307+
return GeneratedImage(
308+
image_bytes=base_image._image_bytes, # pylint: disable=protected-access
309+
generation_parameters=generation_parameters,
310+
)
311+
312+
def save(self, location: str, include_generation_parameters: bool = True):
313+
"""Saves image to a file.
314+
315+
Args:
316+
location: Local path where to save the image.
317+
include_generation_parameters: Whether to include the image
318+
generation parameters in the image's EXIF metadata.
319+
"""
320+
if include_generation_parameters:
321+
if not self._generation_parameters:
322+
raise ValueError("Image does not have generation parameters.")
323+
if not PIL:
324+
raise ValueError("The PIL module is required for saving generation parameters.")
325+
326+
exif = self._pil_image.getexif()
327+
exif[_EXIF_USER_COMMENT_TAG_IDX] = json.dumps(
328+
{_IMAGE_GENERATION_PARAMETERS_EXIF_KEY: self._generation_parameters}
329+
)
330+
self._pil_image.save(location, exif=exif)
331+
else:
332+
super().save(location=location)

0 commit comments

Comments
 (0)