Skip to content

Commit d009f13

Browse files
committed
Merge branch 'main' into text-error
Change-Id: Ia01fef6de0c0787461127c0112cbfe4e8c4cf8d8
2 parents e069551 + e0928fc commit d009f13

File tree

6 files changed

+129
-9
lines changed

6 files changed

+129
-9
lines changed

google/generativeai/responder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _generate_schema(
116116
inspect.Parameter.POSITIONAL_ONLY,
117117
)
118118
}
119-
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
119+
parameters = pydantic.create_model(f.__name__, **fields_dict).model_json_schema()
120120
# Postprocessing
121121
# 4. Suppress unnecessary title generation:
122122
# * https://github.com/pydantic/pydantic/issues/1051

google/generativeai/types/content_types.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,34 @@
7373

7474

7575
def pil_to_blob(img):
76+
# When you load an image with PIL you get a subclass of PIL.Image
77+
# The subclass knows what file type it was loaded from it has a `.format` class attribute
78+
# and the `get_format_mimetype` method. Convert these back to the same file type.
79+
#
80+
# The base image class doesn't know its file type, it just knows its mode.
81+
# RGBA converts to PNG easily, P[allet] converts to GIF, RGB to GIF.
82+
# But for anything else I'm not going to bother mapping it out (for now) let's just convert to RGB and send it.
83+
#
84+
# References:
85+
# - file formats: https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html
86+
# - image modes: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#modes
87+
7688
bytesio = io.BytesIO()
77-
if isinstance(img, PIL.PngImagePlugin.PngImageFile) or img.mode == "RGBA":
89+
90+
get_mime = getattr(img, "get_format_mimetype", None)
91+
if get_mime is not None:
92+
# If the image is created from a file, convert back to the same file type.
93+
img.save(bytesio, format=img.format)
94+
mime_type = img.get_format_mimetype()
95+
elif img.mode == "RGBA":
7896
img.save(bytesio, format="PNG")
7997
mime_type = "image/png"
98+
elif img.mode == "P":
99+
img.save(bytesio, format="GIF")
100+
mime_type = "image/gif"
80101
else:
102+
if img.mode != "RGB":
103+
img = img.convert("RGB")
81104
img.save(bytesio, format="JPEG")
82105
mime_type = "image/jpeg"
83106
bytesio.seek(0)
@@ -379,7 +402,7 @@ def _schema_for_function(
379402

380403

381404
def _build_schema(fname, fields_dict):
382-
parameters = pydantic.create_model(fname, **fields_dict).schema()
405+
parameters = pydantic.create_model(fname, **fields_dict).model_json_schema()
383406
defs = parameters.pop("$defs", {})
384407
# flatten the defs
385408
for name, value in defs.items():
@@ -401,7 +424,10 @@ def _build_schema(fname, fields_dict):
401424

402425

403426
def unpack_defs(schema, defs):
404-
properties = schema["properties"]
427+
properties = schema.get("properties", None)
428+
if properties is None:
429+
return
430+
405431
for name, value in properties.items():
406432
ref_key = value.get("$ref", None)
407433
if ref_key is not None:

google/generativeai/types/generation_types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import collections
1818
import contextlib
19-
import sys
2019
from collections.abc import Iterable, AsyncIterable, Mapping
2120
import dataclasses
2221
import itertools
@@ -165,7 +164,7 @@ class GenerationConfig:
165164
top_p: float | None = None
166165
top_k: int | None = None
167166
response_mime_type: str | None = None
168-
response_schema: protos.Schema | Mapping[str, Any] | None = None
167+
response_schema: protos.Schema | Mapping[str, Any] | type | None = None
169168

170169

171170
GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
@@ -186,7 +185,8 @@ def _normalize_schema(generation_config):
186185
if not str(response_schema).startswith("list["):
187186
raise ValueError(
188187
f"Invalid input: Could not understand the type of '{response_schema}'. "
189-
"Expected one of the following types: `int`, `float`, `str`, `bool`, `typing_extensions.TypedDict`, `dataclass`, or `list[...]`."
188+
"Expected one of the following types: `int`, `float`, `str`, `bool`, `enum`, "
189+
"`typing_extensions.TypedDict`, `dataclass` or `list[...]`."
190190
)
191191
response_schema = content_types._schema_for_class(response_schema)
192192

samples/controlled_generation.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
from absl.testing import absltest
14+
import pathlib
1415

1516
import google.generativeai as genai
1617

18+
media = pathlib.Path(__file__).parents[1] / "third_party"
19+
1720

1821
class UnitTests(absltest.TestCase):
1922
def test_json_controlled_generation(self):
@@ -22,6 +25,7 @@ def test_json_controlled_generation(self):
2225

2326
class Recipe(typing.TypedDict):
2427
recipe_name: str
28+
ingredients: list[str]
2529

2630
model = genai.GenerativeModel("gemini-1.5-pro-latest")
2731
result = model.generate_content(
@@ -36,14 +40,57 @@ class Recipe(typing.TypedDict):
3640
def test_json_no_schema(self):
3741
# [START json_no_schema]
3842
model = genai.GenerativeModel("gemini-1.5-pro-latest")
39-
prompt = """List a few popular cookie recipes using this JSON schema:
43+
prompt = """List a few popular cookie recipes in JSON format.
44+
45+
Use this JSON schema:
4046
41-
Recipe = {'recipe_name': str}
47+
Recipe = {'recipe_name': str, 'ingredients': list[str]}
4248
Return: list[Recipe]"""
4349
result = model.generate_content(prompt)
4450
print(result)
4551
# [END json_no_schema]
4652

53+
def test_json_enum(self):
54+
# [START json_enum]
55+
import enum
56+
57+
class Choice(enum.Enum):
58+
PERCUSSION = "Percussion"
59+
STRING = "String"
60+
WOODWIND = "Woodwind"
61+
BRASS = "Brass"
62+
KEYBOARD = "Keyboard"
63+
64+
model = genai.GenerativeModel("gemini-1.5-pro-latest")
65+
66+
organ = genai.upload_file(media / "organ.jpg")
67+
result = model.generate_content(
68+
["What kind of instrument is this:", organ],
69+
generation_config=genai.GenerationConfig(
70+
response_mime_type="application/json", response_schema=Choice
71+
),
72+
)
73+
print(result) # "Keyboard"
74+
# [END json_enum]
75+
76+
def test_json_enum_raw(self):
77+
# [START json_enum_raw]
78+
model = genai.GenerativeModel("gemini-1.5-pro-latest")
79+
80+
organ = genai.upload_file(media / "organ.jpg")
81+
result = model.generate_content(
82+
["What kind of instrument is this:", organ],
83+
generation_config=genai.GenerationConfig(
84+
response_mime_type="application/json",
85+
response_schema={
86+
"type": "STRING",
87+
"enum": ["Percussion", "String", "Woodwind", "Brass", "Keyboard"],
88+
},
89+
),
90+
)
91+
print(result) # "Keyboard"
92+
# [END json_enum_raw]
93+
4794

4895
if __name__ == "__main__":
4996
absltest.main()

tests/test_content.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import dataclasses
16+
import enum
1617
import pathlib
1718
import typing_extensions
1819
from typing import Any, Union, Iterable
@@ -35,6 +36,10 @@
3536
TEST_JPG_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.jpg"
3637
TEST_JPG_DATA = TEST_JPG_PATH.read_bytes()
3738

39+
TEST_GIF_PATH = HERE / "test_img.gif"
40+
TEST_GIF_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.gif"
41+
TEST_GIF_DATA = TEST_GIF_PATH.read_bytes()
42+
3843

3944
# simple test function
4045
def datetime():
@@ -65,6 +70,18 @@ class ADataClassWithList:
6570
a: list[int]
6671

6772

73+
class Choices(enum.Enum):
74+
A = "a"
75+
B = "b"
76+
C = "c"
77+
D = "d"
78+
79+
80+
@dataclasses.dataclass
81+
class HasEnum:
82+
choice: Choices
83+
84+
6885
class UnitTests(parameterized.TestCase):
6986
@parameterized.named_parameters(
7087
["PIL", PIL.Image.open(TEST_PNG_PATH)],
@@ -88,6 +105,17 @@ def test_jpg_to_blob(self, image):
88105
self.assertEqual(blob.mime_type, "image/jpeg")
89106
self.assertStartsWith(blob.data, b"\xff\xd8\xff\xe0\x00\x10JFIF")
90107

108+
@parameterized.named_parameters(
109+
["PIL", PIL.Image.open(TEST_GIF_PATH)],
110+
["P", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8)).convert("P")],
111+
["IPython", IPython.display.Image(filename=TEST_GIF_PATH)],
112+
)
113+
def test_gif_to_blob(self, image):
114+
blob = content_types.image_to_blob(image)
115+
self.assertIsInstance(blob, protos.Blob)
116+
self.assertEqual(blob.mime_type, "image/gif")
117+
self.assertStartsWith(blob.data, b"GIF87a")
118+
91119
@parameterized.named_parameters(
92120
["BlobDict", {"mime_type": "image/png", "data": TEST_PNG_DATA}],
93121
["protos.Blob", protos.Blob(mime_type="image/png", data=TEST_PNG_DATA)],
@@ -536,6 +564,25 @@ def b():
536564
},
537565
),
538566
],
567+
["enum", Choices, protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"])],
568+
[
569+
"enum_list",
570+
list[Choices],
571+
protos.Schema(
572+
type="ARRAY",
573+
items=protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"]),
574+
),
575+
],
576+
[
577+
"has_enum",
578+
HasEnum,
579+
protos.Schema(
580+
type=protos.Type.OBJECT,
581+
properties={
582+
"choice": protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"])
583+
},
584+
),
585+
],
539586
)
540587
def test_auto_schema(self, annotation, expected):
541588
def fun(a: annotation):

tests/test_img.gif

353 Bytes
Loading

0 commit comments

Comments
 (0)