Commit 17b5d56

make check_watermark a stand alone function
Change-Id: I2d72620359dcc70fe8e720a14f78d83f75a42d90
1 parent 63e0501 commit 17b5d56

2 files changed (+47, -36 lines)

google/generativeai/vision_models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
 """Classes for working with vision models."""

 from google.generativeai.vision_models._vision_models import (
+    check_watermark,
     Image,
     GeneratedImage,
     ImageGenerationModel,
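
With this commit, `check_watermark` is exported from `google.generativeai.vision_models`, so watermark verification no longer requires constructing an `Image` first. A minimal usage sketch, assuming an API key is already configured and `generated.png` is a hypothetical local file:

    import pathlib
    from google.generativeai.vision_models import check_watermark

    # An Image, pathlib.Path, PIL.Image.Image, or IPython display Image all work.
    result = check_watermark(pathlib.Path("generated.png"))  # hypothetical local file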

google/generativeai/vision_models/_vision_models.py

Lines changed: 46 additions & 36 deletions
@@ -114,6 +114,50 @@ def to_mapping_value(value) -> struct_pb2.Struct:
 ImageLikeType = Union["Image", pathlib.Path, content_types.ImageType]


+def check_watermark(
+    img: ImageLikeType, model_id: str = "models/image-verification-001"
+) -> "CheckWatermarkResult":
+    """Checks if an image has a Google-AI watermark.
+
+    Args:
+        img: Can be a `pathlib.Path`, `PIL.Image.Image`, `IPython.display.Image`, or `google.generativeai.Image`.
+        model_id: Which version of the image-verification model to send the image to.
+
+    Returns:
+        A `CheckWatermarkResult` wrapping the service's predictions.
+    """
+    if isinstance(img, Image):
+        pass
+    elif isinstance(img, pathlib.Path):
+        img = Image.load_from_file(img)
+    elif IPython_display is not None and isinstance(img, IPython_display.Image):
+        img = Image(image_bytes=img.data)
+    elif PIL_Image is not None and isinstance(img, PIL_Image.Image):
+        blob = content_types._pil_to_blob(img)
+        img = Image(image_bytes=blob.data)
+    elif isinstance(img, protos.Blob):
+        img = Image(image_bytes=img.data)
+    else:
+        raise TypeError(
+            f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
+        )
+
+    prediction_client = client.get_default_prediction_client()
+    if not model_id.startswith("models/"):
+        model_id = f"models/{model_id}"
+
+    instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}}
+    parameters = {"watermarkVerification": True}
+
+    # This is to get around https://github.com/googleapis/proto-plus-python/issues/488
+    pr = protos.PredictRequest.pb()
+    request = pr(model=model_id, instances=[to_value(instance)], parameters=to_value(parameters))
+
+    response = prediction_client.predict(request)
+
+    return CheckWatermarkResult(response.predictions)
+
+
 class Image:
     """Image."""

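The standalone function coerces its argument to an `Image`, base64-encodes the loaded bytes into a single prediction `instance`, and requests watermark verification via `parameters`. A small standard-library sketch of that payload shape (the byte string is a placeholder, not a real image):

    import base64

    image_bytes = b"\x89PNG\r\n..."  # placeholder bytes, not a real image
    instance = {"image": {"bytesBase64Encoded": base64.b64encode(image_bytes).decode()}}
    parameters = {"watermarkVerification": True}
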
@@ -209,41 +253,7 @@ def _as_base64_string(self) -> str:
     def _repr_png_(self):
         return self._pil_image._repr_png_()  # type:ignore

-    def check_watermark(self: ImageLikeType, model_id: str = "models/image-verification-001"):
-        img = None
-        if isinstance(self, Image):
-            img = self
-        elif isinstance(self, pathlib.Path):
-            img = Image.load_from_file(self)
-        elif IPython_display is not None and isinstance(self, IPython_display.Image):
-            img = Image(image_bytes=self.data)
-        elif PIL_Image is not None and isinstance(self, PIL_Image.Image):
-            blob = content_types._pil_to_blob(self)
-            img = Image(image_bytes=blob.data)
-        elif isinstance(self, protos.Blob):
-            img = Image(image_bytes=self.data)
-        else:
-            raise TypeError(
-                f"Not implemented: Could not convert a {type(img)} into `Image`\n {img=}"
-            )
-
-        prediction_client = client.get_default_prediction_client()
-        if not model_id.startswith("models/"):
-            model_id = f"models/{model_id}"
-
-        # Note: Only a single prompt is supported by the service.
-        instance = {"image": {"bytesBase64Encoded": base64.b64encode(img._loaded_bytes).decode()}}
-        parameters = {"watermarkVerification": True}
-
-        # This is to get around https://github.com/googleapis/proto-plus-python/issues/488
-        pr = protos.PredictRequest.pb()
-        request = pr(
-            model=model_id, instances=[to_value(instance)], parameters=to_value(parameters)
-        )
-
-        response = prediction_client.predict(request)
-
-        return CheckWatermarkResult(response.predictions)
+    check_watermark = check_watermark


 class CheckWatermarkResult:
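
The single added line, `check_watermark = check_watermark`, binds the module-level function as a class attribute, so the existing `Image(...).check_watermark()` call style keeps working: Python's descriptor protocol passes the instance as the function's first argument (`img`). A minimal sketch of the pattern with hypothetical names:

    def describe(thing) -> str:
        # Module-level function; only needs a `.name` attribute on its argument.
        return f"I am {thing.name}"

    class Widget:
        name = "a widget"
        describe = describe  # re-export the free function as a method

    assert Widget().describe() == "I am a widget"  # bound call: the instance becomes `thing`
    assert describe(Widget()) == "I am a widget"   # still usable as a standalone function
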
@@ -539,7 +549,7 @@ def generation_parameters(self):
         return self._generation_parameters

     @staticmethod
-    def load_from_file(location: str) -> "GeneratedImage":
+    def load_from_file(location: os.PathLike) -> "GeneratedImage":
         """Loads image from file.

         Args:
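
The last hunk only changes the annotation on `GeneratedImage.load_from_file` from `str` to `os.PathLike`. A brief usage sketch (the file path is hypothetical):

    import pathlib
    from google.generativeai.vision_models import GeneratedImage

    # A pathlib.Path now matches the annotation; the file itself is hypothetical.
    image = GeneratedImage.load_from_file(pathlib.Path("outputs/sample_0.png"))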