diff --git a/google/generativeai/responder.py b/google/generativeai/responder.py index bb85167ad..dd388c6a6 100644 --- a/google/generativeai/responder.py +++ b/google/generativeai/responder.py @@ -116,7 +116,7 @@ def _generate_schema( inspect.Parameter.POSITIONAL_ONLY, ) } - parameters = pydantic.create_model(f.__name__, **fields_dict).schema() + parameters = pydantic.create_model(f.__name__, **fields_dict).model_json_schema() # Postprocessing # 4. Suppress unnecessary title generation: # * https://github.com/pydantic/pydantic/issues/1051 diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index e2e2b680d..f80f145f5 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -379,7 +379,7 @@ def _schema_for_function( def _build_schema(fname, fields_dict): - parameters = pydantic.create_model(fname, **fields_dict).schema() + parameters = pydantic.create_model(fname, **fields_dict).model_json_schema() defs = parameters.pop("$defs", {}) # flatten the defs for name, value in defs.items(): @@ -401,7 +401,10 @@ def _build_schema(fname, fields_dict): def unpack_defs(schema, defs): - properties = schema["properties"] + properties = schema.get("properties", None) + if properties is None: + return + for name, value in properties.items(): ref_key = value.get("$ref", None) if ref_key is not None: diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index d4bed8b86..84689a922 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -16,7 +16,6 @@ import collections import contextlib -import sys from collections.abc import Iterable, AsyncIterable, Mapping import dataclasses import itertools @@ -165,7 +164,7 @@ class GenerationConfig: top_p: float | None = None top_k: int | None = None response_mime_type: str | None = None - response_schema: protos.Schema | Mapping[str, Any] | None = None + response_schema: protos.Schema | Mapping[str, Any] | type | None = None GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig] @@ -186,7 +185,8 @@ def _normalize_schema(generation_config): if not str(response_schema).startswith("list["): raise ValueError( f"Invalid input: Could not understand the type of '{response_schema}'. " - "Expected one of the following types: `int`, `float`, `str`, `bool`, `typing_extensions.TypedDict`, `dataclass`, or `list[...]`." + "Expected one of the following types: `int`, `float`, `str`, `bool`, `enum`, " + "`typing_extensions.TypedDict`, `dataclass` or `list[...]`." ) response_schema = content_types._schema_for_class(response_schema) diff --git a/samples/controlled_generation.py b/samples/controlled_generation.py index b0c269bb7..4942481f6 100644 --- a/samples/controlled_generation.py +++ b/samples/controlled_generation.py @@ -11,9 +11,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from absl.testing import absltest +import pathlib import google.generativeai as genai +media = pathlib.Path(__file__).parents[1] / "third_party" + class UnitTests(absltest.TestCase): def test_json_controlled_generation(self): @@ -22,6 +25,7 @@ def test_json_controlled_generation(self): class Recipe(typing.TypedDict): recipe_name: str + ingredients: list[str] model = genai.GenerativeModel("gemini-1.5-pro-latest") result = model.generate_content( @@ -36,14 +40,57 @@ class Recipe(typing.TypedDict): def test_json_no_schema(self): # [START json_no_schema] model = genai.GenerativeModel("gemini-1.5-pro-latest") - prompt = """List a few popular cookie recipes using this JSON schema: + prompt = """List a few popular cookie recipes in JSON format. + + Use this JSON schema: - Recipe = {'recipe_name': str} + Recipe = {'recipe_name': str, 'ingredients': list[str]} Return: list[Recipe]""" result = model.generate_content(prompt) print(result) # [END json_no_schema] + def test_json_enum(self): + # [START json_enum] + import enum + + class Choice(enum.Enum): + PERCUSSION = "Percussion" + STRING = "String" + WOODWIND = "Woodwind" + BRASS = "Brass" + KEYBOARD = "Keyboard" + + model = genai.GenerativeModel("gemini-1.5-pro-latest") + + organ = genai.upload_file(media / "organ.jpg") + result = model.generate_content( + ["What kind of instrument is this:", organ], + generation_config=genai.GenerationConfig( + response_mime_type="application/json", response_schema=Choice + ), + ) + print(result) # "Keyboard" + # [END json_enum] + + def test_json_enum_raw(self): + # [START json_enum_raw] + model = genai.GenerativeModel("gemini-1.5-pro-latest") + + organ = genai.upload_file(media / "organ.jpg") + result = model.generate_content( + ["What kind of instrument is this:", organ], + generation_config=genai.GenerationConfig( + response_mime_type="application/json", + response_schema={ + "type": "STRING", + "enum": ["Percussion", "String", "Woodwind", "Brass", "Keyboard"], + }, + ), + ) + print(result) # "Keyboard" + # [END json_enum_raw] + if __name__ == "__main__": absltest.main() diff --git a/tests/test_content.py b/tests/test_content.py index b52858bb8..676537b4e 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import dataclasses +import enum import pathlib import typing_extensions from typing import Any, Union, Iterable @@ -65,6 +66,18 @@ class ADataClassWithList: a: list[int] +class Choices(enum.Enum): + A = "a" + B = "b" + C = "c" + D = "d" + + +@dataclasses.dataclass +class HasEnum: + choice: Choices + + class UnitTests(parameterized.TestCase): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_PNG_PATH)], @@ -536,6 +549,25 @@ def b(): }, ), ], + ["enum", Choices, protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"])], + [ + "enum_list", + list[Choices], + protos.Schema( + type="ARRAY", + items=protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"]), + ), + ], + [ + "has_enum", + HasEnum, + protos.Schema( + type=protos.Type.OBJECT, + properties={ + "choice": protos.Schema(type=protos.Type.STRING, enum=["a", "b", "c", "d"]) + }, + ), + ], ) def test_auto_schema(self, annotation, expected): def fun(a: annotation):