Skip to content

Commit e0928fc

Browse files
authored
Enum (#529)
* try to support enum types
  Change-Id: I5141f751c4d6c578ef957aa8250cb26309ea9bd3
* format
  Change-Id: I9619654247f0f7230c8ba4c76035ad0ff9324fd4
* Be clear that test uses enum value.
  Change-Id: I03e319f2795c7c15f527316a145d021620936c57
* Add samples
  Change-Id: Ifc5e5b2039c9f0532d37386f6d7b136961943bac
* Fix type annotations.
  Change-Id: I6b7b769cf0ba17fc7188518cdcec3085f59760b0
1 parent e805b24 commit e0928fc

File tree

5 files changed: +90 -8 lines changed

5 files changed: +90 -8 lines changed

google/generativeai/responder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _generate_schema(
116116
inspect.Parameter.POSITIONAL_ONLY,
117117
)
118118
}
119-
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
119+
parameters = pydantic.create_model(f.__name__, **fields_dict).model_json_schema()
120120
# Postprocessing
121121
# 4. Suppress unnecessary title generation:
122122
# * https://github.com/pydantic/pydantic/issues/1051

google/generativeai/types/content_types.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def _schema_for_function(
402402

403403

404404
def _build_schema(fname, fields_dict):
405-
parameters = pydantic.create_model(fname, **fields_dict).schema()
405+
parameters = pydantic.create_model(fname, **fields_dict).model_json_schema()
406406
defs = parameters.pop("$defs", {})
407407
# flatten the defs
408408
for name, value in defs.items():
@@ -424,7 +424,10 @@ def _build_schema(fname, fields_dict):
424424

425425

426426
def unpack_defs(schema, defs):
427-
properties = schema["properties"]
427+
properties = schema.get("properties", None)
428+
if properties is None:
429+
return
430+
428431
for name, value in properties.items():
429432
ref_key = value.get("$ref", None)
430433
if ref_key is not None:

google/generativeai/types/generation_types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import collections
1818
import contextlib
19-
import sys
2019
from collections.abc import Iterable, AsyncIterable, Mapping
2120
import dataclasses
2221
import itertools
@@ -165,7 +164,7 @@ class GenerationConfig:
165164
top_p: float | None = None
166165
top_k: int | None = None
167166
response_mime_type: str | None = None
168-
response_schema: protos.Schema | Mapping[str, Any] | None = None
167+
response_schema: protos.Schema | Mapping[str, Any] | type | None = None
169168

170169

171170
GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
@@ -186,7 +185,8 @@ def _normalize_schema(generation_config):
186185
if not str(response_schema).startswith("list["):
187186
raise ValueError(
188187
f"Invalid input: Could not understand the type of '{response_schema}'. "
189-
"Expected one of the following types: `int`, `float`, `str`, `bool`, `typing_extensions.TypedDict`, `dataclass`, or `list[...]`."
188+
"Expected one of the following types: `int`, `float`, `str`, `bool`, `enum`, "
189+
"`typing_extensions.TypedDict`, `dataclass` or `list[...]`."
190190
)
191191
response_schema = content_types._schema_for_class(response_schema)
192192

samples/controlled_generation.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
from absl.testing import absltest
14+
import pathlib
1415

1516
import google.generativeai as genai
1617

18+
media = pathlib.Path(__file__).parents[1] / "third_party"
19+
1720

1821
class UnitTests(absltest.TestCase):
1922
def test_json_controlled_generation(self):
@@ -22,6 +25,7 @@ def test_json_controlled_generation(self):
2225

2326
class Recipe(typing.TypedDict):
2427
recipe_name: str
28+
ingredients: list[str]
2529

2630
model = genai.GenerativeModel("gemini-1.5-pro-latest")
2731
result = model.generate_content(
@@ -36,14 +40,57 @@ class Recipe(typing.TypedDict):
3640
def test_json_no_schema(self):
3741
# [START json_no_schema]
3842
model = genai.GenerativeModel("gemini-1.5-pro-latest")
39-
prompt = """List a few popular cookie recipes using this JSON schema:
43+
prompt = """List a few popular cookie recipes in JSON format.
44+
45+
Use this JSON schema:
4046
41-
Recipe = {'recipe_name': str}
47+
Recipe = {'recipe_name': str, 'ingredients': list[str]}
4248
Return: list[Recipe]"""
4349
result = model.generate_content(prompt)
4450
print(result)
4551
# [END json_no_schema]
4652

53+
def test_json_enum(self):
54+
# [START json_enum]
55+
import enum
56+
57+
class Choice(enum.Enum):
58+
PERCUSSION = "Percussion"
59+
STRING = "String"
60+
WOODWIND = "Woodwind"
61+
BRASS = "Brass"
62+
KEYBOARD = "Keyboard"
63+
64+
model = genai.GenerativeModel("gemini-1.5-pro-latest")
65+
66+
organ = genai.upload_file(media / "organ.jpg")
67+
result = model.generate_content(
68+
["What kind of instrument is this:", organ],
69+
generation_config=genai.GenerationConfig(
70+
response_mime_type="application/json", response_schema=Choice
71+
),
72+
)
73+
print(result) # "Keyboard"
74+
# [END json_enum]
75+
76+
def test_json_enum_raw(self):
77+
# [START json_enum_raw]
78+
model = genai.GenerativeModel("gemini-1.5-pro-latest")
79+
80+
organ = genai.upload_file(media / "organ.jpg")
81+
result = model.generate_content(
82+
["What kind of instrument is this:", organ],
83+
generation_config=genai.GenerationConfig(
84+
response_mime_type="application/json",
85+
response_schema={
86+
"type": "STRING",
87+
"enum": ["Percussion", "String", "Woodwind", "Brass", "Keyboard"],
88+
},
89+
),
90+
)
91+
print(result) # "Keyboard"
92+
# [END json_enum_raw]
93+
4794

4895
if __name__ == "__main__":
4996
absltest.main()

tests/test_content.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import dataclasses
16+
import enum
1617
import pathlib
1718
import typing_extensions
1819
from typing import Any, Union, Iterable
@@ -69,6 +70,18 @@ class ADataClassWithList:
6970
a: list[int]
7071

7172

73+
class Choices(enum.Enum):
74+
A = "a"
75+
B = "b"
76+
C = "c"
77+
D = "d"
78+
79+
80+
@dataclasses.dataclass
81+
class HasEnum:
82+
choice: Choices
83+
84+
7285
class UnitTests(parameterized.TestCase):
7386
@parameterized.named_parameters(
7487
["PIL", PIL.Image.open(TEST_PNG_PATH)],
@@ -551,6 +564,25 @@ def b():
551564
},
552565
),
553566
],
567+
["enum", Choices, protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"])],
568+
[
569+
"enum_list",
570+
list[Choices],
571+
protos.Schema(
572+
type="ARRAY",
573+
items=protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"]),
574+
),
575+
],
576+
[
577+
"has_enum",
578+
HasEnum,
579+
protos.Schema(
580+
type=protos.Type.OBJECT,
581+
properties={
582+
"choice": protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"])
583+
},
584+
),
585+
],
554586
)
555587
def test_auto_schema(self, annotation, expected):
556588
def fun(a: annotation):

0 commit comments

Comments
 (0)