Skip to content

Commit 836d31a

Browse files
authored
Add more enum samples. (#543)
* Add more enum samples Change-Id: I743d5967cc1cc91576b8ddf5a60db1767d94508d * format Change-Id: I8f6f9389f1cae0a7c934217968d4e2e20bb9590e
1 parent 4647e79 commit 836d31a

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

samples/controlled_generation.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,34 @@ class Choice(enum.Enum):
7373
print(result) # "Keyboard"
7474
# [END json_enum]
7575

76+
def test_enum_in_json(self):
77+
# [START enum_in_json]
78+
import enum
79+
from typing_extensions import TypedDict
80+
81+
class Grade(enum.Enum):
82+
A_PLUS = "a+"
83+
A = "a"
84+
B = "b"
85+
C = "c"
86+
D = "d"
87+
F = "f"
88+
89+
class Recipe(TypedDict):
90+
recipe_name: str
91+
grade: Grade
92+
93+
model = genai.GenerativeModel("gemini-1.5-pro-latest")
94+
95+
result = model.generate_content(
96+
"List about 10 cookie recipes, grade them based on popularity",
97+
generation_config=genai.GenerationConfig(
98+
response_mime_type="application/json", response_schema=list[Recipe]
99+
),
100+
)
101+
print(result) # [{"grade": "a+", "recipe_name": "Chocolate Chip Cookies"}, ...]
102+
# [END enum_in_json]
103+
76104
def test_json_enum_raw(self):
77105
# [START json_enum_raw]
78106
model = genai.GenerativeModel("gemini-1.5-pro-latest")
@@ -91,6 +119,47 @@ def test_json_enum_raw(self):
91119
print(result) # "Keyboard"
92120
# [END json_enum_raw]
93121

122+
def test_x_enum(self):
123+
# [START x_enum]
124+
import enum
125+
126+
class Choice(enum.Enum):
127+
PERCUSSION = "Percussion"
128+
STRING = "String"
129+
WOODWIND = "Woodwind"
130+
BRASS = "Brass"
131+
KEYBOARD = "Keyboard"
132+
133+
model = genai.GenerativeModel("gemini-1.5-pro-latest")
134+
135+
organ = genai.upload_file(media / "organ.jpg")
136+
result = model.generate_content(
137+
["What kind of instrument is this:", organ],
138+
generation_config=genai.GenerationConfig(
139+
response_mime_type="text/x.enum", response_schema=Choice
140+
),
141+
)
142+
print(result) # "Keyboard"
143+
# [END x_enum]
144+
145+
def test_x_enum_raw(self):
146+
# [START x_enum_raw]
147+
model = genai.GenerativeModel("gemini-1.5-pro-latest")
148+
149+
organ = genai.upload_file(media / "organ.jpg")
150+
result = model.generate_content(
151+
["What kind of instrument is this:", organ],
152+
generation_config=genai.GenerationConfig(
153+
response_mime_type="text/x.enum",
154+
response_schema={
155+
"type": "STRING",
156+
"enum": ["Percussion", "String", "Woodwind", "Brass", "Keyboard"],
157+
},
158+
),
159+
)
160+
print(result) # "Keyboard"
161+
# [END x_enum_raw]
162+
94163

95164
if __name__ == "__main__":
96165
absltest.main()

0 commit comments

Comments
 (0)