Skip to content

Commit a96feda

Browse files
Add response_schema parameter (#302)
* Add response_schema parameter * Update types for response_schema * fix type Change-Id: I90e9c4218f041687c3b50e620305b6eff09b650a * Update type to Mapping[str, Any] * Update Any import * Add black . format check * check black . precheck * Remove seed parameter for now * Update google/generativeai/types/generation_types.py * Added test cases for response_schema, function for normalizing schema, and enums for type field in schema --------- Co-authored-by: Mark Daoust <[email protected]>
1 parent c165b20 commit a96feda

File tree

3 files changed

+124
-6
lines changed

3 files changed

+124
-6
lines changed

google/generativeai/responder.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,47 @@
2424

2525
from google.ai import generativelanguage as glm
2626

27+
Type = glm.Type
28+
29+
TypeOptions = Union[int, str, Type]
30+
31+
_TYPE_TYPE: dict[TypeOptions, Type] = {
32+
Type.TYPE_UNSPECIFIED: Type.TYPE_UNSPECIFIED,
33+
0: Type.TYPE_UNSPECIFIED,
34+
"type_unspecified": Type.TYPE_UNSPECIFIED,
35+
"unspecified": Type.TYPE_UNSPECIFIED,
36+
Type.STRING: Type.STRING,
37+
1: Type.STRING,
38+
"type_string": Type.STRING,
39+
"string": Type.STRING,
40+
Type.NUMBER: Type.NUMBER,
41+
2: Type.NUMBER,
42+
"type_number": Type.NUMBER,
43+
"number": Type.NUMBER,
44+
Type.INTEGER: Type.INTEGER,
45+
3: Type.INTEGER,
46+
"type_integer": Type.INTEGER,
47+
"integer": Type.INTEGER,
48+
Type.BOOLEAN: Type.BOOLEAN,
49+
4: Type.INTEGER,
50+
"type_boolean": Type.BOOLEAN,
51+
"boolean": Type.BOOLEAN,
52+
Type.ARRAY: Type.ARRAY,
53+
5: Type.ARRAY,
54+
"type_array": Type.ARRAY,
55+
"array": Type.ARRAY,
56+
Type.OBJECT: Type.OBJECT,
57+
6: Type.OBJECT,
58+
"type_object": Type.OBJECT,
59+
"object": Type.OBJECT,
60+
}
61+
62+
63+
def to_type(x: TypeOptions) -> Type:
64+
if isinstance(x, str):
65+
x = x.lower()
66+
return _TYPE_TYPE[x]
67+
2768

2869
def _generate_schema(
2970
f: Callable[..., Any],
@@ -115,15 +156,18 @@ def _generate_schema(
115156
return schema
116157

117158

118-
def _rename_schema_fields(schema):
159+
def _rename_schema_fields(schema: dict[str, Any]):
119160
if schema is None:
120161
return schema
121162

122163
schema = schema.copy()
123164

124165
type_ = schema.pop("type", None)
125166
if type_ is not None:
126-
schema["type_"] = type_.upper()
167+
schema["type_"] = type_
168+
type_ = schema.get("type_", None)
169+
if type_ is not None:
170+
schema["type_"] = to_type(type_)
127171

128172
format_ = schema.pop("format", None)
129173
if format_ is not None:

google/generativeai/types/generation_types.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,22 @@
1616

1717
import collections
1818
import contextlib
19-
from collections.abc import Iterable, AsyncIterable
19+
import sys
20+
from collections.abc import Iterable, AsyncIterable, Mapping
2021
import dataclasses
2122
import itertools
2223
import json
2324
import sys
2425
import textwrap
25-
from typing import Union
26+
from typing import Union, Any
2627
from typing_extensions import TypedDict
2728

2829
import google.protobuf.json_format
2930
import google.api_core.exceptions
3031

3132
from google.ai import generativelanguage as glm
3233
from google.generativeai import string_utils
34+
from google.generativeai.responder import _rename_schema_fields
3335

3436
__all__ = [
3537
"AsyncGenerateContentResponse",
@@ -81,6 +83,7 @@ class GenerationConfigDict(TypedDict, total=False):
8183
max_output_tokens: int
8284
temperature: float
8385
response_mime_type: str
86+
response_schema: glm.Schema | Mapping[str, Any] # fmt: off
8487

8588

8689
@dataclasses.dataclass
@@ -147,6 +150,10 @@ class GenerationConfig:
147150
Supported mimetype:
148151
`text/plain`: (default) Text output.
149152
`application/json`: JSON response in the candidates.
153+
154+
response_schema:
155+
Optional. Specifies the format of the JSON requested if response_mime_type is
156+
`application/json`.
150157
"""
151158

152159
candidate_count: int | None = None
@@ -156,21 +163,41 @@ class GenerationConfig:
156163
top_p: float | None = None
157164
top_k: int | None = None
158165
response_mime_type: str | None = None
166+
response_schema: glm.Schema | Mapping[str, Any] | None = None
159167

160168

161169
GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig]
162170

163171

172+
def _normalize_schema(generation_config):
173+
# Convert response_schema to glm.Schema for request
174+
response_schema = generation_config.get("response_schema", None)
175+
if response_schema is None:
176+
return
177+
if isinstance(response_schema, glm.Schema):
178+
return
179+
response_schema = _rename_schema_fields(response_schema)
180+
generation_config["response_schema"] = glm.Schema(response_schema)
181+
182+
164183
def to_generation_config_dict(generation_config: GenerationConfigType):
165184
if generation_config is None:
166185
return {}
167186
elif isinstance(generation_config, glm.GenerationConfig):
168-
return type(generation_config).to_dict(generation_config) # pytype: disable=attribute-error
187+
schema = generation_config.response_schema
188+
generation_config = type(generation_config).to_dict(
189+
generation_config
190+
) # pytype: disable=attribute-error
191+
generation_config["response_schema"] = schema
192+
return generation_config
169193
elif isinstance(generation_config, GenerationConfig):
170194
generation_config = dataclasses.asdict(generation_config)
195+
_normalize_schema(generation_config)
171196
return {key: value for key, value in generation_config.items() if value is not None}
172197
elif hasattr(generation_config, "keys"):
173-
return dict(generation_config)
198+
generation_config = dict(generation_config)
199+
_normalize_schema(generation_config)
200+
return generation_config
174201
else:
175202
raise TypeError(
176203
"Did not understand `generation_config`, expected a `dict` or"

tests/test_generation.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,53 @@ def test_repr_for_generate_content_response_from_iterator(self):
561561
)
562562
self.assertEqual(expected, result)
563563

564+
@parameterized.named_parameters(
565+
[
566+
"glm.GenerationConfig",
567+
glm.GenerationConfig(
568+
temperature=0.1,
569+
stop_sequences=["end"],
570+
response_mime_type="application/json",
571+
response_schema=glm.Schema(
572+
type="STRING", format="float", description="This is an example schema."
573+
),
574+
),
575+
],
576+
[
577+
"GenerationConfigDict",
578+
{
579+
"temperature": 0.1,
580+
"stop_sequences": ["end"],
581+
"response_mime_type": "application/json",
582+
"response_schema": glm.Schema(
583+
type="STRING", format="float", description="This is an example schema."
584+
),
585+
},
586+
],
587+
[
588+
"GenerationConfig",
589+
generation_types.GenerationConfig(
590+
temperature=0.1,
591+
stop_sequences=["end"],
592+
response_mime_type="application/json",
593+
response_schema=glm.Schema(
594+
type="STRING", format="float", description="This is an example schema."
595+
),
596+
),
597+
],
598+
)
599+
def test_response_schema(self, config):
600+
gd = generation_types.to_generation_config_dict(config)
601+
self.assertIsInstance(gd, dict)
602+
self.assertEqual(gd["temperature"], 0.1)
603+
self.assertEqual(gd["stop_sequences"], ["end"])
604+
self.assertEqual(gd["response_mime_type"], "application/json")
605+
actual = gd["response_schema"]
606+
expected = glm.Schema(
607+
type="STRING", format="float", description="This is an example schema."
608+
)
609+
self.assertEqual(actual, expected)
610+
564611

565612
if __name__ == "__main__":
566613
absltest.main()

0 commit comments

Comments
 (0)