Skip to content

Commit aae0caf

Browse files
MarkDaoustmarkmcd
andauthored
watermark (#603)
* watermark Change-Id: I271993e07bfc034d89cd74013339046a21d0b472 * fix typing Change-Id: Ib026f3b66fb44c6ad35a15030b03204306b35443 * format Change-Id: I24304a80e6689b4fcec3155f7274e24db712b402 * fix test Change-Id: I68cb74f8ee5f224942417ea4e5fc4232ec688977 * make check_watermark a stand alone function Change-Id: I2d72620359dcc70fe8e720a14f78d83f75a42d90 * simplify typing. Change-Id: I1b901cc40b4b029cb09699fc6eac77690622b6e8 * Update google/generativeai/vision_models/_vision_models.py * Typo Co-authored-by: Mark McDonald <[email protected]> --------- Co-authored-by: Mark McDonald <[email protected]>
1 parent e09b902 commit aae0caf

File tree

4 files changed

+78
-5
lines changed

4 files changed

+78
-5
lines changed

google/generativeai/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@
7979

8080
__version__ = version.__version__
8181

82-
del embedding
8382
del files
8483
del generative_models
8584
del models

google/generativeai/types/content_types.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import IPython.display
3636

3737
IMAGE_TYPES = (PIL.Image.Image, IPython.display.Image)
38+
ImageType = PIL.Image.Image | IPython.display.Image
3839
else:
3940
IMAGE_TYPES = ()
4041
try:
@@ -52,6 +53,8 @@
5253
except ImportError:
5354
IPython = None
5455

56+
ImageType = Union["PIL.Image.Image", "IPython.display.Image"]
57+
5558

5659
__all__ = [
5760
"BlobDict",
@@ -123,7 +126,7 @@ def webp_blob(image: PIL.Image.Image) -> protos.Blob:
123126
return file_blob(image) or webp_blob(image)
124127

125128

126-
def image_to_blob(image) -> protos.Blob:
129+
def image_to_blob(image: ImageType) -> protos.Blob:
127130
if PIL is not None:
128131
if isinstance(image, PIL.Image.Image):
129132
return _pil_to_blob(image)

google/generativeai/vision_models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Classes for working with vision models."""
1616

1717
from google.generativeai.vision_models._vision_models import (
18+
check_watermark,
1819
Image,
1920
GeneratedImage,
2021
ImageGenerationModel,

google/generativeai/vision_models/_vision_models.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
import base64
1919
import collections
2020
import dataclasses
21-
import hashlib
2221
import io
2322
import json
23+
import os
2424
import pathlib
2525
import typing
2626
from typing import Any, Dict, List, Literal, Optional, Union
2727

2828
from google.generativeai import client
2929
from google.generativeai import protos
30+
from google.generativeai.types import content_types
3031

3132
from google.protobuf import struct_pb2
3233

@@ -110,6 +111,52 @@ def to_mapping_value(value) -> struct_pb2.Struct:
110111
PersonGeneration = Literal["dont_allow", "allow_adult"]
111112
PERSON_GENERATIONS = PersonGeneration.__args__ # type: ignore
112113

114+
ImageLikeType = Union["Image", pathlib.Path, content_types.ImageType]
115+
116+
117+
def check_watermark(
118+
img: ImageLikeType, model_id: str = "models/image-verification-001"
119+
) -> "CheckWatermarkResult":
120+
"""Checks if an image has a Google-AI watermark.
121+
122+
Args:
123+
img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPython.display.Image`, or `google.generativeai.Image`.
124+
model_id: Which version of the image-verification model to send the image to.
125+
126+
Returns:
127+
128+
"""
129+
if isinstance(img, Image):
130+
pass
131+
elif isinstance(img, pathlib.Path):
132+
img = Image.load_from_file(img)
133+
elif IPython_display is not None and isinstance(img, IPython_display.Image):
134+
img = Image(image_bytes=img.data)
135+
elif PIL_Image is not None and isinstance(img, PIL_Image.Image):
136+
blob = content_types._pil_to_blob(img)
137+
img = Image(image_bytes=blob.data)
138+
elif isinstance(img, protos.Blob):
139+
img = Image(image_bytes=img.data)
140+
else:
141+
raise TypeError(
142+
f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
143+
)
144+
145+
prediction_client = client.get_default_prediction_client()
146+
if not model_id.startswith("models/"):
147+
model_id = f"models/{model_id}"
148+
149+
instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}}
150+
parameters = {"watermarkVerification": True}
151+
152+
# This is to get around https://github.com/googleapis/proto-plus-python/issues/488
153+
pr = protos.PredictRequest.pb()
154+
request = pr(model=model_id, instances=[to_value(instance)], parameters=to_value(parameters))
155+
156+
response = prediction_client.predict(request)
157+
158+
return CheckWatermarkResult(response.predictions)
159+
113160

114161
class Image:
115162
"""Image."""
@@ -131,7 +178,7 @@ def __init__(
131178
self._image_bytes = image_bytes
132179

133180
@staticmethod
134-
def load_from_file(location: str) -> "Image":
181+
def load_from_file(location: os.PathLike) -> "Image":
135182
"""Loads image from local file or Google Cloud Storage.
136183
137184
Args:
@@ -206,6 +253,29 @@ def _as_base64_string(self) -> str:
206253
def _repr_png_(self):
207254
return self._pil_image._repr_png_() # type:ignore
208255

256+
check_watermark = check_watermark
257+
258+
259+
class CheckWatermarkResult:
260+
def __init__(self, predictions):
261+
self._predictions = predictions
262+
263+
@property
264+
def decision(self):
265+
return self._predictions[0]["decision"]
266+
267+
def __str__(self):
268+
return f"CheckWatermarkResult([{{'decision': {self.decision!r}}}])"
269+
270+
def __bool__(self):
271+
decision = self.decision
272+
if decision == "ACCEPT":
273+
return True
274+
elif decision == "REJECT":
275+
return False
276+
else:
277+
raise ValueError(f"Unrecognized result: {decision}")
278+
209279

210280
class ImageGenerationModel:
211281
"""Generates images from text prompt.
@@ -479,7 +549,7 @@ def generation_parameters(self):
479549
return self._generation_parameters
480550

481551
@staticmethod
482-
def load_from_file(location: str) -> "GeneratedImage":
552+
def load_from_file(location: os.PathLike) -> "GeneratedImage":
483553
"""Loads image from file.
484554
485555
Args:

0 commit comments

Comments
 (0)