Skip to content

Commit bfca8b7

Browse files
committed
try to support enum types
Change-Id: I5141f751c4d6c578ef957aa8250cb26309ea9bd3
1 parent 32b754f commit bfca8b7

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

google/generativeai/responder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _generate_schema(
116116
inspect.Parameter.POSITIONAL_ONLY,
117117
)
118118
}
119-
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
119+
parameters = pydantic.create_model(f.__name__, **fields_dict).model_json_schema()
120120
# Postprocessing
121121
# 4. Suppress unnecessary title generation:
122122
# * https://github.com/pydantic/pydantic/issues/1051

google/generativeai/types/content_types.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def _schema_for_function(
379379

380380

381381
def _build_schema(fname, fields_dict):
382-
parameters = pydantic.create_model(fname, **fields_dict).schema()
382+
parameters = pydantic.create_model(fname, **fields_dict).model_json_schema()
383383
defs = parameters.pop("$defs", {})
384384
# flatten the defs
385385
for name, value in defs.items():
@@ -401,7 +401,10 @@ def _build_schema(fname, fields_dict):
401401

402402

403403
def unpack_defs(schema, defs):
404-
properties = schema["properties"]
404+
properties = schema.get("properties", None)
405+
if properties is None:
406+
return
407+
405408
for name, value in properties.items():
406409
ref_key = value.get("$ref", None)
407410
if ref_key is not None:

tests/test_content.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import dataclasses
16+
import enum
1617
import pathlib
1718
import typing_extensions
1819
from typing import Any, Union, Iterable
@@ -64,6 +65,18 @@ class ADataClassWithNullable:
6465
class ADataClassWithList:
6566
a: list[int]
6667

68+
class Choices(enum.Enum):
69+
A:str = "A"
70+
B:str = "B"
71+
C:str = "C"
72+
D:str = "D"
73+
74+
75+
@dataclasses.dataclass
76+
class HasEnum:
77+
choice:Choices
78+
79+
6780

6881
class UnitTests(parameterized.TestCase):
6982
@parameterized.named_parameters(
@@ -536,6 +549,20 @@ def b():
536549
},
537550
),
538551
],
552+
['enum', Choices, protos.Schema(type=protos.Type.STRING, enum=['A', 'B', 'C', 'D'])],
553+
['enum_list',
554+
list[Choices],
555+
protos.Schema(
556+
type="ARRAY",
557+
items=protos.Schema(type=protos.Type.STRING, enum=['A', 'B', 'C', 'D']),
558+
),
559+
],
560+
['has_enum',
561+
HasEnum,
562+
protos.Schema(
563+
type=protos.Type.OBJECT,
564+
properties={'choice': protos.Schema(type=protos.Type.STRING, enum=['A', 'B', 'C', 'D'])})
565+
]
539566
)
540567
def test_auto_schema(self, annotation, expected):
541568
def fun(a: annotation):

0 commit comments

Comments
 (0)