Skip to content
Merged

Enum #529

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion google/generativeai/responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _generate_schema(
inspect.Parameter.POSITIONAL_ONLY,
)
}
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
parameters = pydantic.create_model(f.__name__, **fields_dict).model_json_schema()
# Postprocessing
# 4. Suppress unnecessary title generation:
# * https://github.com/pydantic/pydantic/issues/1051
Expand Down
7 changes: 5 additions & 2 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def _schema_for_function(


def _build_schema(fname, fields_dict):
parameters = pydantic.create_model(fname, **fields_dict).schema()
parameters = pydantic.create_model(fname, **fields_dict).model_json_schema()
defs = parameters.pop("$defs", {})
# flatten the defs
for name, value in defs.items():
Expand All @@ -401,7 +401,10 @@ def _build_schema(fname, fields_dict):


def unpack_defs(schema, defs):
properties = schema["properties"]
properties = schema.get("properties", None)
if properties is None:
return

for name, value in properties.items():
ref_key = value.get("$ref", None)
if ref_key is not None:
Expand Down
6 changes: 3 additions & 3 deletions google/generativeai/types/generation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import collections
import contextlib
import sys
from collections.abc import Iterable, AsyncIterable, Mapping
import dataclasses
import itertools
Expand Down Expand Up @@ -165,7 +164,7 @@ class GenerationConfig:
top_p: float | None = None
top_k: int | None = None
response_mime_type: str | None = None
response_schema: protos.Schema | Mapping[str, Any] | None = None
response_schema: protos.Schema | Mapping[str, Any] | type | None = None


GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig]
Expand All @@ -186,7 +185,8 @@ def _normalize_schema(generation_config):
if not str(response_schema).startswith("list["):
raise ValueError(
f"Invalid input: Could not understand the type of '{response_schema}'. "
"Expected one of the following types: `int`, `float`, `str`, `bool`, `typing_extensions.TypedDict`, `dataclass`, or `list[...]`."
"Expected one of the following types: `int`, `float`, `str`, `bool`, `enum`, "
"`typing_extensions.TypedDict`, `dataclass` or `list[...]`."
)
response_schema = content_types._schema_for_class(response_schema)

Expand Down
51 changes: 49 additions & 2 deletions samples/controlled_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import pathlib

import google.generativeai as genai

media = pathlib.Path(__file__).parents[1] / "third_party"


class UnitTests(absltest.TestCase):
def test_json_controlled_generation(self):
Expand All @@ -22,6 +25,7 @@ def test_json_controlled_generation(self):

class Recipe(typing.TypedDict):
recipe_name: str
ingredients: list[str]

model = genai.GenerativeModel("gemini-1.5-pro-latest")
result = model.generate_content(
Expand All @@ -36,14 +40,57 @@ class Recipe(typing.TypedDict):
def test_json_no_schema(self):
# [START json_no_schema]
model = genai.GenerativeModel("gemini-1.5-pro-latest")
prompt = """List a few popular cookie recipes using this JSON schema:
prompt = """List a few popular cookie recipes in JSON format.

Use this JSON schema:

Recipe = {'recipe_name': str}
Recipe = {'recipe_name': str, 'ingredients': list[str]}
Return: list[Recipe]"""
result = model.generate_content(prompt)
print(result)
# [END json_no_schema]

def test_json_enum(self):
# [START json_enum]
import enum

class Choice(enum.Enum):
PERCUSSION = "Percussion"
STRING = "String"
WOODWIND = "Woodwind"
BRASS = "Brass"
KEYBOARD = "Keyboard"

model = genai.GenerativeModel("gemini-1.5-pro-latest")

organ = genai.upload_file(media / "organ.jpg")
result = model.generate_content(
["What kind of instrument is this:", organ],
generation_config=genai.GenerationConfig(
response_mime_type="application/json", response_schema=Choice
),
)
print(result) # "Keyboard"
# [END json_enum]

def test_json_enum_raw(self):
# [START json_enum_raw]
model = genai.GenerativeModel("gemini-1.5-pro-latest")

organ = genai.upload_file(media / "organ.jpg")
result = model.generate_content(
["What kind of instrument is this:", organ],
generation_config=genai.GenerationConfig(
response_mime_type="application/json",
response_schema={
"type": "STRING",
"enum": ["Percussion", "String", "Woodwind", "Brass", "Keyboard"],
},
),
)
print(result) # "Keyboard"
# [END json_enum_raw]


if __name__ == "__main__":
absltest.main()
32 changes: 32 additions & 0 deletions tests/test_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import enum
import pathlib
import typing_extensions
from typing import Any, Union, Iterable
Expand Down Expand Up @@ -65,6 +66,18 @@ class ADataClassWithList:
a: list[int]


class Choices(enum.Enum):
A = "a"
B = "b"
C = "c"
D = "d"


@dataclasses.dataclass
class HasEnum:
choice: Choices


class UnitTests(parameterized.TestCase):
@parameterized.named_parameters(
["PIL", PIL.Image.open(TEST_PNG_PATH)],
Expand Down Expand Up @@ -536,6 +549,25 @@ def b():
},
),
],
["enum", Choices, protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"])],
[
"enum_list",
list[Choices],
protos.Schema(
type="ARRAY",
items=protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"]),
),
],
[
"has_enum",
HasEnum,
protos.Schema(
type=protos.Type.OBJECT,
properties={
"choice": protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"])
},
),
],
)
def test_auto_schema(self, annotation, expected):
def fun(a: annotation):
Expand Down
Loading