Skip to content

Commit 165aeb0

Browse files
committed
Add more enum samples
Change-Id: I743d5967cc1cc91576b8ddf5a60db1767d94508d
1 parent e0928fc commit 165aeb0

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

samples/controlled_generation.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,36 @@ class Choice(enum.Enum):
7373
print(result) # "Keyboard"
7474
# [END json_enum]
7575

76+
def test_enum_in_json(self):
77+
# [START enum_in_json]
78+
import enum
79+
from typing_extensions import TypedDict
80+
81+
class Grade(enum.Enum):
82+
A_PLUS = "a+"
83+
A = "a"
84+
B = "b"
85+
C = "c"
86+
D = "d"
87+
F = "f"
88+
89+
class Recipe(TypedDict):
90+
recipe_name: str
91+
grade: Grade
92+
93+
94+
model = genai.GenerativeModel("gemini-1.5-pro-latest")
95+
96+
result = model.generate_content(
97+
"List about 10 cookie recipes, grade them based on popularity",
98+
generation_config=genai.GenerationConfig(
99+
response_mime_type="application/json",
100+
response_schema=list[Recipe]
101+
),
102+
)
103+
print(result) # [{"grade": "a+", "recipe_name": "Chocolate Chip Cookies"}, ...]
104+
# [END enum_in_json]
105+
76106
def test_json_enum_raw(self):
77107
# [START json_enum_raw]
78108
model = genai.GenerativeModel("gemini-1.5-pro-latest")
@@ -92,5 +122,46 @@ def test_json_enum_raw(self):
92122
# [END json_enum_raw]
93123

94124

125+
def test_x_enum(self):
126+
# [START x_enum]
127+
import enum
128+
129+
class Choice(enum.Enum):
130+
PERCUSSION = "Percussion"
131+
STRING = "String"
132+
WOODWIND = "Woodwind"
133+
BRASS = "Brass"
134+
KEYBOARD = "Keyboard"
135+
136+
model = genai.GenerativeModel("gemini-1.5-pro-latest")
137+
138+
organ = genai.upload_file(media / "organ.jpg")
139+
result = model.generate_content(
140+
["What kind of instrument is this:", organ],
141+
generation_config=genai.GenerationConfig(
142+
response_mime_type="text/x.enum", response_schema=Choice
143+
),
144+
)
145+
print(result) # "Keyboard"
146+
# [END x_enum]
147+
148+
def test_x_enum_raw(self):
149+
# [START x_enum_raw]
150+
model = genai.GenerativeModel("gemini-1.5-pro-latest")
151+
152+
organ = genai.upload_file(media / "organ.jpg")
153+
result = model.generate_content(
154+
["What kind of instrument is this:", organ],
155+
generation_config=genai.GenerationConfig(
156+
response_mime_type="text/x.enum",
157+
response_schema={
158+
"type": "STRING",
159+
"enum": ["Percussion", "String", "Woodwind", "Brass", "Keyboard"],
160+
},
161+
),
162+
)
163+
print(result) # "Keyboard"
164+
# [END x_enum_raw]
165+
95166
if __name__ == "__main__":
96167
absltest.main()

0 commit comments

Comments
 (0)