Skip to content

Commit e8d69a8

Browse files
committed
watermark
Change-Id: I271993e07bfc034d89cd74013339046a21d0b472
1 parent e09b902 commit e8d69a8

File tree

3 files changed

+66
-4
lines changed

3 files changed

+66
-4
lines changed

google/generativeai/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@
7979

8080
__version__ = version.__version__
8181

82-
del embedding
8382
del files
8483
del generative_models
8584
del models

google/generativeai/types/content_types.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import IPython.display
3636

3737
IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
38+
ImageType = PIL.Image.Image | IPython.display.Image
3839
else:
3940
IMAGE_TYPES = ()
4041
try:
@@ -52,6 +53,8 @@
5253
except ImportError:
5354
IPython = None
5455

56+
ImageType = Union[*IMAGE_TYPES]
57+
5558

5659
__all__ = [
5760
"BlobDict",
@@ -123,7 +126,7 @@ def webp_blob(image: PIL.Image.Image) -> protos.Blob:
123126
return file_blob(image) or webp_blob(image)
124127

125128

126-
def image_to_blob(image) -> protos.Blob:
129+
def image_to_blob(image: ImageType) -> protos.Blob:
127130
if PIL is not None:
128131
if isinstance(image, PIL.Image.Image):
129132
return _pil_to_blob(image)

google/generativeai/vision_models/_vision_models.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
import base64
1919
import collections
2020
import dataclasses
21-
import hashlib
2221
import io
2322
import json
23+
import os
2424
import pathlib
2525
import typing
2626
from typing import Any, Dict, List, Literal, Optional, Union
2727

2828
from google.generativeai import client
2929
from google.generativeai import protos
30+
from google.generativeai.types import content_types
3031

3132
from google.protobuf import struct_pb2
3233

@@ -110,6 +111,8 @@ def to_mapping_value(value) -> struct_pb2.Struct:
110111
PersonGeneration = Literal["dont_allow", "allow_adult"]
111112
PERSON_GENERATIONS = PersonGeneration.__args__ # type: ignore
112113

114+
ImageLikeType = Union["Image", pathlib.Path(), content_types.ImageType]
115+
113116

114117
class Image:
115118
"""Image."""
@@ -131,7 +134,7 @@ def __init__(
131134
self._image_bytes = image_bytes
132135

133136
@staticmethod
134-
def load_from_file(location: str) -> "Image":
137+
def load_from_file(location: os.PathLike) -> "Image":
135138
"""Loads image from local file or Google Cloud Storage.
136139
137140
Args:
@@ -206,6 +209,63 @@ def _as_base64_string(self) -> str:
206209
def _repr_png_(self):
207210
return self._pil_image._repr_png_() # type:ignore
208211

212+
def check_watermark(self: ImageLikeType, model_id: str = "models/image-verification-001"):
213+
img = None
214+
if isinstance(self, Image):
215+
img = self
216+
elif isinstance(self, pathlib.Path):
217+
img = Image.load_from_file(self)
218+
elif IPython_display is not None and isinstance(self, IPython_display.Image):
219+
img = Image(image_bytes=self.data)
220+
elif PIL_Image is not None and isinstance(self, PIL_Image.Image):
221+
blob = content_types._pil_to_blob(self)
222+
img = Image(image_bytes=blob.data)
223+
elif isinstance(self, protos.Blob):
224+
img = Image(image_bytes=self.data)
225+
else:
226+
raise TypeError(
227+
f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
228+
)
229+
230+
prediction_client = client.get_default_prediction_client()
231+
if not model_id.startswith("models/"):
232+
model_id = f"models/{model_id}"
233+
234+
# Note: Only a single prompt is supported by the service.
235+
instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}}
236+
parameters = {"watermarkVerification": True}
237+
238+
# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
239+
pr = protos.PredictRequest.pb()
240+
request = pr(
241+
model=model_id, instances=[to_value(instance)], parameters=to_value(parameters)
242+
)
243+
244+
response = prediction_client.predict(request)
245+
246+
return CheckWatermarkResult(response.predictions)
247+
248+
249+
class CheckWatermarkResult:
250+
def __init__(self, predictions):
251+
self._predictions = predictions
252+
253+
@property
254+
def decision(self):
255+
return self._predictions[0]["decision"]
256+
257+
def __str__(self):
258+
return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])"
259+
260+
def __bool__(self):
261+
decision = self.decision
262+
if decision == "ACCEPT":
263+
return True
264+
elif decision == "REJECT":
265+
return False
266+
else:
267+
raise ValueError("Unrecognized result")
268+
209269

210270
class ImageGenerationModel:
211271
"""Generates images from text prompt.

0 commit comments

Comments
 (0)