diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index e2e2b680d..71d1e9907 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -73,11 +73,34 @@ def pil_to_blob(img): + # When you load an image with PIL you get a subclass of PIL.Image.Image. + # The subclass knows what file type it was loaded from: it has a `.format` class attribute + # and the `get_format_mimetype` method. Convert these back to the same file type. + # + # The base image class doesn't know its file type, it just knows its mode. + # RGBA converts to PNG easily, P[alette] converts to GIF, RGB to JPEG. + # For anything else I'm not going to bother mapping it out (for now); let's just convert to RGB and send it as JPEG. + # + # References: + # - file formats: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html + # - image modes: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes + bytesio = io.BytesIO() - if isinstance(img, PIL.PngImagePlugin.PngImageFile) or img.mode == "RGBA": + + get_mime = getattr(img, "get_format_mimetype", None) + if get_mime is not None: + # If the image is created from a file, convert back to the same file type. 
+ img.save(bytesio, format=img.format) + mime_type = img.get_format_mimetype() + elif img.mode == "RGBA": img.save(bytesio, format="PNG") mime_type = "image/png" + elif img.mode == "P": + img.save(bytesio, format="GIF") + mime_type = "image/gif" else: + if img.mode != "RGB": + img = img.convert("RGB") img.save(bytesio, format="JPEG") mime_type = "image/jpeg" bytesio.seek(0) diff --git a/tests/test_content.py b/tests/test_content.py index b52858bb8..da763dc33 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -35,6 +35,10 @@ TEST_JPG_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.jpg" TEST_JPG_DATA = TEST_JPG_PATH.read_bytes() +TEST_GIF_PATH = HERE / "test_img.gif" +TEST_GIF_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.gif" +TEST_GIF_DATA = TEST_GIF_PATH.read_bytes() + # simple test function def datetime(): @@ -88,6 +92,17 @@ def test_jpg_to_blob(self, image): self.assertEqual(blob.mime_type, "image/jpeg") self.assertStartsWith(blob.data, b"\xff\xd8\xff\xe0\x00\x10JFIF") + @parameterized.named_parameters( + ["PIL", PIL.Image.open(TEST_GIF_PATH)], + ["P", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8)).convert("P")], + ["IPython", IPython.display.Image(filename=TEST_GIF_PATH)], + ) + def test_gif_to_blob(self, image): + blob = content_types.image_to_blob(image) + self.assertIsInstance(blob, protos.Blob) + self.assertEqual(blob.mime_type, "image/gif") + self.assertStartsWith(blob.data, b"GIF87a") + @parameterized.named_parameters( ["BlobDict", {"mime_type": "image/png", "data": TEST_PNG_DATA}], ["protos.Blob", protos.Blob(mime_type="image/png", data=TEST_PNG_DATA)], diff --git a/tests/test_img.gif b/tests/test_img.gif new file mode 100644 index 000000000..66c81ac7a Binary files /dev/null and b/tests/test_img.gif differ