18
18
import base64
19
19
import collections
20
20
import dataclasses
21
- import hashlib
22
21
import io
23
22
import json
23
+ import os
24
24
import pathlib
25
25
import typing
26
26
from typing import Any , Dict , List , Literal , Optional , Union
27
27
28
28
from google .generativeai import client
29
29
from google .generativeai import protos
30
+ from google .generativeai .types import content_types
30
31
31
32
from google .protobuf import struct_pb2
32
33
@@ -110,6 +111,8 @@ def to_mapping_value(value) -> struct_pb2.Struct:
110
111
PersonGeneration = Literal ["dont_allow" , "allow_adult" ]
111
112
PERSON_GENERATIONS = PersonGeneration .__args__ # type: ignore
112
113
114
+ ImageLikeType = Union ["Image" , pathlib .Path (), content_types .ImageType ]
115
+
113
116
114
117
class Image :
115
118
"""Image."""
@@ -131,7 +134,7 @@ def __init__(
131
134
self ._image_bytes = image_bytes
132
135
133
136
@staticmethod
134
- def load_from_file (location : str ) -> "Image" :
137
+ def load_from_file (location : os . PathLike ) -> "Image" :
135
138
"""Loads image from local file or Google Cloud Storage.
136
139
137
140
Args:
@@ -206,6 +209,63 @@ def _as_base64_string(self) -> str:
206
209
def _repr_png_ (self ):
207
210
return self ._pil_image ._repr_png_ () # type:ignore
208
211
212
+ def check_watermark (self : ImageLikeType , model_id : str = "models/image-verification-001" ):
213
+ img = None
214
+ if isinstance (self , Image ):
215
+ img = self
216
+ elif isinstance (self , pathlib .Path ):
217
+ img = Image .load_from_file (self )
218
+ elif IPython_display is not None and isinstance (self , IPython_display .Image ):
219
+ img = Image (image_bytes = self .data )
220
+ elif PIL_Image is not None and isinstance (self , PIL_Image .Image ):
221
+ blob = content_types ._pil_to_blob (self )
222
+ img = Image (image_bytes = blob .data )
223
+ elif isinstance (self , protos .Blob ):
224
+ img = Image (image_bytes = self .data )
225
+ else :
226
+ raise TypeError (
227
+ f"Not implemented: Could not convert a { type (img )} into `Image`\n { img = } "
228
+ )
229
+
230
+ prediction_client = client .get_default_prediction_client ()
231
+ if not model_id .startswith ("models/" ):
232
+ model_id = f"models/{ model_id } "
233
+
234
+ # Note: Only a single prompt is supported by the service.
235
+ instance = {"image" : {"bytesBase64Encoded" : base64 .b64encode (img ._loaded_bytes ).decode ()}}
236
+ parameters = {"watermarkVerification" : True }
237
+
238
+ # This is to get around https://github.com/googleapis/proto-plus-python/issues/488
239
+ pr = protos .PredictRequest .pb ()
240
+ request = pr (
241
+ model = model_id , instances = [to_value (instance )], parameters = to_value (parameters )
242
+ )
243
+
244
+ response = prediction_client .predict (request )
245
+
246
+ return CheckWatermarkResult (response .predictions )
247
+
248
+
249
+ class CheckWatermarkResult :
250
+ def __init__ (self , predictions ):
251
+ self ._predictions = predictions
252
+
253
+ @property
254
+ def decision (self ):
255
+ return self ._predictions [0 ]["decision" ]
256
+
257
+ def __str__ (self ):
258
+ return f"CheckWatermarkResult([{{'decision': { self .decision !r} }}])"
259
+
260
+ def __bool__ (self ):
261
+ decision = self .decision
262
+ if decision == "ACCEPT" :
263
+ return True
264
+ elif decision == "REJECT" :
265
+ return False
266
+ else :
267
+ raise ValueError ("Unrecognized result" )
268
+
209
269
210
270
class ImageGenerationModel :
211
271
"""Generates images from text prompt.
0 commit comments