18
18
import base64
19
19
import collections
20
20
import dataclasses
21
- import hashlib
22
21
import io
23
22
import json
23
+ import os
24
24
import pathlib
25
25
import typing
26
26
from typing import Any , Dict , List , Literal , Optional , Union
27
27
28
28
from google .generativeai import client
29
29
from google .generativeai import protos
30
+ from google .generativeai .types import content_types
30
31
31
32
from google .protobuf import struct_pb2
32
33
@@ -110,6 +111,52 @@ def to_mapping_value(value) -> struct_pb2.Struct:
110
111
PersonGeneration = Literal ["dont_allow" , "allow_adult" ]
111
112
PERSON_GENERATIONS = PersonGeneration .__args__ # type: ignore
112
113
114
+ ImageLikeType = Union ["Image" , pathlib .Path , content_types .ImageType ]
115
+
116
+
117
+ def check_watermark (
118
+ img : ImageLikeType , model_id : str = "models/image-verification-001"
119
+ ) -> "CheckWatermarkResult" :
120
+ """Checks if an image has a Google-AI watermark.
121
+
122
+ Args:
123
+ img: can be a `pathlib.Path` or a `PIL.Image.Image`, `IPython.display.Image`, or `google.generativeai.Image`.
124
+ model_id: Which version of the image-verification model to send the image to.
125
+
126
+ Returns:
127
+
128
+ """
129
+ if isinstance (img , Image ):
130
+ pass
131
+ elif isinstance (img , pathlib .Path ):
132
+ img = Image .load_from_file (img )
133
+ elif IPython_display is not None and isinstance (img , IPython_display .Image ):
134
+ img = Image (image_bytes = img .data )
135
+ elif PIL_Image is not None and isinstance (img , PIL_Image .Image ):
136
+ blob = content_types ._pil_to_blob (img )
137
+ img = Image (image_bytes = blob .data )
138
+ elif isinstance (img , protos .Blob ):
139
+ img = Image (image_bytes = img .data )
140
+ else :
141
+ raise TypeError (
142
+ f"Not implemented: Could not convert a { type (img )} into `Image`\n { img = } "
143
+ )
144
+
145
+ prediction_client = client .get_default_prediction_client ()
146
+ if not model_id .startswith ("models/" ):
147
+ model_id = f"models/{ model_id } "
148
+
149
+ instance = {"image" : {"bytesBase64Encoded" : base64 .b64encode (img ._loaded_bytes ).decode ()}}
150
+ parameters = {"watermarkVerification" : True }
151
+
152
+ # This is to get around https://github.com/googleapis/proto-plus-python/issues/488
153
+ pr = protos .PredictRequest .pb ()
154
+ request = pr (model = model_id , instances = [to_value (instance )], parameters = to_value (parameters ))
155
+
156
+ response = prediction_client .predict (request )
157
+
158
+ return CheckWatermarkResult (response .predictions )
159
+
113
160
114
161
class Image :
115
162
"""Image."""
@@ -131,7 +178,7 @@ def __init__(
131
178
self ._image_bytes = image_bytes
132
179
133
180
@staticmethod
134
- def load_from_file (location : str ) -> "Image" :
181
+ def load_from_file (location : os . PathLike ) -> "Image" :
135
182
"""Loads image from local file or Google Cloud Storage.
136
183
137
184
Args:
@@ -206,6 +253,29 @@ def _as_base64_string(self) -> str:
206
253
def _repr_png_ (self ):
207
254
return self ._pil_image ._repr_png_ () # type:ignore
208
255
256
+ check_watermark = check_watermark
257
+
258
+
259
+ class CheckWatermarkResult :
260
+ def __init__ (self , predictions ):
261
+ self ._predictions = predictions
262
+
263
+ @property
264
+ def decision (self ):
265
+ return self ._predictions [0 ]["decision" ]
266
+
267
+ def __str__ (self ):
268
+ return f"CheckWatermarkResult([{{'decision': { self .decision !r} }}])"
269
+
270
+ def __bool__ (self ):
271
+ decision = self .decision
272
+ if decision == "ACCEPT" :
273
+ return True
274
+ elif decision == "REJECT" :
275
+ return False
276
+ else :
277
+ raise ValueError (f"Unrecognized result: { decision } " )
278
+
209
279
210
280
class ImageGenerationModel :
211
281
"""Generates images from text prompt.
@@ -479,7 +549,7 @@ def generation_parameters(self):
479
549
return self ._generation_parameters
480
550
481
551
@staticmethod
482
- def load_from_file (location : str ) -> "GeneratedImage" :
552
+ def load_from_file (location : os . PathLike ) -> "GeneratedImage" :
483
553
"""Loads image from file.
484
554
485
555
Args:
0 commit comments