Skip to content

Commit 4f05500

Browse files
authored
Update docstrings, fix jpeg support. (#108)
* Update docstrings, add jpeg support. * black * Improve typechecking in docs
1 parent 0988543 commit 4f05500

File tree

6 files changed

+58
-25
lines changed

6 files changed

+58
-25
lines changed

docs/build_docs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
import pathlib
2525
import re
2626
import textwrap
27+
import typing
28+
29+
typing.TYPE_CHECKING = True
2730

2831
from absl import app
2932
from absl import flags

google/generativeai/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
from google.generativeai.embedding import embed_content
8383

8484
from google.generativeai.generative_models import GenerativeModel
85+
from google.generativeai.generative_models import ChatSession
8586

8687
from google.generativeai.text import generate_text
8788
from google.generativeai.text import generate_embeddings

google/generativeai/generative_models.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from google.generativeai.types import generation_types
1818
from google.generativeai.types import safety_types
1919

20-
_GENERATE_CONTENT_ASYNC_DOC = """The async version of `Model.generate_content`."""
20+
_GENERATE_CONTENT_ASYNC_DOC = """The async version of `GenerativeModel.generate_content`."""
2121

2222
_GENERATE_CONTENT_DOC = """A multipurpose function to generate responses from the model.
2323
24-
This `GenerativeModel.generate_content` method can handle multimodal input, and multiturn
24+
This `GenerativeModel.generate_content` method can handle multimodal input, and multi-turn
2525
conversations.
2626
2727
>>> model = genai.GenerativeModel('models/gemini-pro')
@@ -289,6 +289,15 @@ def start_chat(
289289
*,
290290
history: Iterable[content_types.StrictContentType] | None = None,
291291
) -> ChatSession:
292+
"""Returns a `genai.ChatSession` attached to this model.
293+
294+
>>> model = genai.GenerativeModel()
295+
>>> chat = model.start_chat(history=[...])
296+
>>> response = chat.send_message("Hello?")
297+
298+
Arguments:
299+
history: An iterable of `glm.Content` objects, or equivalents to initialize the session.
300+
"""
292301
if self._generation_config.get("candidate_count", 1) > 1:
293302
raise ValueError("Can't chat with `candidate_count > 1`")
294303
return ChatSession(

google/generativeai/types/content_types.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,23 @@
4343
]
4444

4545

46-
def pil_to_png_bytes(img):
46+
def pil_to_blob(img):
4747
bytesio = io.BytesIO()
48-
img.save(bytesio, format="PNG")
48+
if isinstance(img, PIL.PngImagePlugin.PngImageFile):
49+
img.save(bytesio, format="PNG")
50+
mime_type = "image/png"
51+
else:
52+
img.save(bytesio, format="JPEG")
53+
mime_type = "image/jpeg"
4954
bytesio.seek(0)
50-
return bytesio.read()
55+
data = bytesio.read()
56+
return glm.Blob(mime_type=mime_type, data=data)
5157

5258

5359
def image_to_blob(image) -> glm.Blob:
5460
if PIL is not None:
5561
if isinstance(image, PIL.Image.Image):
56-
return glm.Blob(mime_type="image/png", data=pil_to_png_bytes(image))
62+
return pil_to_blob(image)
5763

5864
if IPython is not None:
5965
if isinstance(image, IPython.display.Image):
@@ -71,7 +77,7 @@ def image_to_blob(image) -> glm.Blob:
7177
return glm.Blob(mime_type=mime_type, data=image.data)
7278

7379
raise TypeError(
74-
"Could not convert image. epected an `Image` type"
80+
"Could not convert image. expected an `Image` type"
7581
"(`PIL.Image.Image` or `IPython.display.Image`).\n"
7682
f"Got a: {type(image)}\n"
7783
f"Value: {image}"

tests/test_content.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,40 @@
1313

1414

1515
HERE = pathlib.Path(__file__).parent
16-
TEST_IMAGE_PATH = HERE / "test_img.png"
17-
TEST_IMAGE_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.png"
18-
TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes()
16+
TEST_PNG_PATH = HERE / "test_img.png"
17+
TEST_PNG_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.png"
18+
TEST_PNG_DATA = TEST_PNG_PATH.read_bytes()
19+
20+
TEST_JPG_PATH = HERE / "test_img.jpg"
21+
TEST_JPG_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.jpg"
22+
TEST_JPG_DATA = TEST_JPG_PATH.read_bytes()
1923

2024

2125
class UnitTests(parameterized.TestCase):
2226
@parameterized.named_parameters(
23-
["PIL", PIL.Image.open(TEST_IMAGE_PATH)],
24-
["IPython", IPython.display.Image(filename=TEST_IMAGE_PATH)],
27+
["PIL", PIL.Image.open(TEST_PNG_PATH)],
28+
["IPython", IPython.display.Image(filename=TEST_PNG_PATH)],
2529
)
26-
def test_image_to_blob(self, image):
30+
def test_png_to_blob(self, image):
2731
blob = content_types.image_to_blob(image)
2832
self.assertIsInstance(blob, glm.Blob)
2933
self.assertEqual(blob.mime_type, "image/png")
3034
self.assertStartsWith(blob.data, b"\x89PNG")
3135

3236
@parameterized.named_parameters(
33-
["BlobDict", {"mime_type": "image/png", "data": TEST_IMAGE_DATA}],
34-
["glm.Blob", glm.Blob(mime_type="image/png", data=TEST_IMAGE_DATA)],
35-
["Image", IPython.display.Image(filename=TEST_IMAGE_PATH)],
37+
["PIL", PIL.Image.open(TEST_JPG_PATH)],
38+
["IPython", IPython.display.Image(filename=TEST_JPG_PATH)],
39+
)
40+
def test_jpg_to_blob(self, image):
41+
blob = content_types.image_to_blob(image)
42+
self.assertIsInstance(blob, glm.Blob)
43+
self.assertEqual(blob.mime_type, "image/jpeg")
44+
self.assertStartsWith(blob.data, b"\xff\xd8\xff\xe0\x00\x10JFIF")
45+
46+
@parameterized.named_parameters(
47+
["BlobDict", {"mime_type": "image/png", "data": TEST_PNG_DATA}],
48+
["glm.Blob", glm.Blob(mime_type="image/png", data=TEST_PNG_DATA)],
49+
["Image", IPython.display.Image(filename=TEST_PNG_PATH)],
3650
)
3751
def test_to_blob(self, example):
3852
blob = content_types.to_blob(example)
@@ -51,11 +65,11 @@ def test_to_part(self, example):
5165
self.assertEqual(part.text, "Hello world!")
5266

5367
@parameterized.named_parameters(
54-
["Image", IPython.display.Image(filename=TEST_IMAGE_PATH)],
55-
["BlobDict", {"mime_type": "image/png", "data": TEST_IMAGE_DATA}],
68+
["Image", IPython.display.Image(filename=TEST_PNG_PATH)],
69+
["BlobDict", {"mime_type": "image/png", "data": TEST_PNG_DATA}],
5670
[
5771
"PartDict",
58-
{"inline_data": {"mime_type": "image/png", "data": TEST_IMAGE_DATA}},
72+
{"inline_data": {"mime_type": "image/png", "data": TEST_PNG_DATA}},
5973
],
6074
)
6175
def test_img_to_part(self, example):
@@ -83,9 +97,9 @@ def test_to_content(self, example):
8397
self.assertEqual(part.text, "Hello world!")
8498

8599
@parameterized.named_parameters(
86-
["ContentDict", {"parts": [PIL.Image.open(TEST_IMAGE_PATH)]}],
87-
["list[Image]", [PIL.Image.open(TEST_IMAGE_PATH)]],
88-
["Image", PIL.Image.open(TEST_IMAGE_PATH)],
100+
["ContentDict", {"parts": [PIL.Image.open(TEST_PNG_PATH)]}],
101+
["list[Image]", [PIL.Image.open(TEST_PNG_PATH)]],
102+
["Image", PIL.Image.open(TEST_PNG_PATH)],
89103
)
90104
def test_img_to_content(self, example):
91105
content = content_types.to_content(example)
@@ -140,10 +154,10 @@ def test_dict_to_content_fails(self):
140154
@parameterized.named_parameters(
141155
[
142156
"ContentDict",
143-
[{"parts": [{"inline_data": PIL.Image.open(TEST_IMAGE_PATH)}]}],
157+
[{"parts": [{"inline_data": PIL.Image.open(TEST_PNG_PATH)}]}],
144158
],
145-
["ContentDict-unwraped", [{"parts": [PIL.Image.open(TEST_IMAGE_PATH)]}]],
146-
["Image", PIL.Image.open(TEST_IMAGE_PATH)],
159+
["ContentDict-unwraped", [{"parts": [PIL.Image.open(TEST_PNG_PATH)]}]],
160+
["Image", PIL.Image.open(TEST_PNG_PATH)],
147161
)
148162
def test_img_to_contents(self, example):
149163
contents = content_types.to_contents(example)

tests/test_img.jpg

4.13 KB
Loading

0 commit comments

Comments
 (0)