Skip to content

Commit 3a9c286

Browse files
Copy test cases over for new responder.py (#303)
* Moving code related to function calling into a separate file. * Formatted with black . * Remove unused imports * Add _generate_schema to responder.py * Add imports for _generate_schema * Add license * add test cases for responder.py * fix indent Change-Id: I55e2d72d7f27eb2c0c89285de7b6fab17d9a20f0 --------- Co-authored-by: Mark Daoust <[email protected]>
1 parent f63e15d commit 3a9c286

File tree

2 files changed

+272
-20
lines changed

2 files changed

+272
-20
lines changed

google/generativeai/responder.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -147,26 +147,26 @@ def __init__(self, *, name: str, description: str, parameters: dict[str, Any] |
147147
name=name, description=description, parameters=_rename_schema_fields(parameters)
148148
)
149149

150-
@property
151-
def name(self) -> str:
152-
return self._proto.name
153-
154-
@property
155-
def description(self) -> str:
156-
return self._proto.description
157-
158-
@property
159-
def parameters(self) -> glm.Schema:
160-
return self._proto.parameters
161-
162-
@classmethod
163-
def from_proto(cls, proto) -> FunctionDeclaration:
164-
self = cls(name="", description="", parameters={})
165-
self._proto = proto
166-
return self
167-
168-
def to_proto(self) -> glm.FunctionDeclaration:
169-
return self._proto
150+
@property
151+
def name(self) -> str:
152+
return self._proto.name
153+
154+
@property
155+
def description(self) -> str:
156+
return self._proto.description
157+
158+
@property
159+
def parameters(self) -> glm.Schema:
160+
return self._proto.parameters
161+
162+
@classmethod
163+
def from_proto(cls, proto) -> FunctionDeclaration:
164+
self = cls(name="", description="", parameters={})
165+
self._proto = proto
166+
return self
167+
168+
def to_proto(self) -> glm.FunctionDeclaration:
169+
return self._proto
170170

171171
@staticmethod
172172
def from_function(function: Callable[..., Any], descriptions: dict[str, str] | None = None):

tests/test_responder.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2023 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import pathlib
16+
from typing import Any
17+
18+
from absl.testing import absltest
19+
from absl.testing import parameterized
20+
import google.ai.generativelanguage as glm
21+
from google.generativeai import responder
22+
import IPython.display
23+
import PIL.Image
24+
25+
HERE = pathlib.Path(__file__).parent
26+
TEST_PNG_PATH = HERE / "test_img.png"
27+
TEST_PNG_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.png"
28+
TEST_PNG_DATA = TEST_PNG_PATH.read_bytes()
29+
30+
TEST_JPG_PATH = HERE / "test_img.jpg"
31+
TEST_JPG_URL = "https://storage.googleapis.com/generativeai-downloads/data/test_img.jpg"
32+
TEST_JPG_DATA = TEST_JPG_PATH.read_bytes()
33+
34+
35+
# simple test function
36+
def datetime():
37+
"Returns the current UTC date and time."
38+
39+
40+
class UnitTests(parameterized.TestCase):
41+
@parameterized.named_parameters(
42+
[
43+
"FunctionLibrary",
44+
responder.FunctionLibrary(
45+
tools=glm.Tool(
46+
function_declarations=[
47+
glm.FunctionDeclaration(
48+
name="datetime", description="Returns the current UTC date and time."
49+
)
50+
]
51+
)
52+
),
53+
],
54+
[
55+
"IterableTool-Tool",
56+
[
57+
responder.Tool(
58+
function_declarations=[
59+
glm.FunctionDeclaration(
60+
name="datetime", description="Returns the current UTC date and time."
61+
)
62+
]
63+
)
64+
],
65+
],
66+
[
67+
"IterableTool-glm.Tool",
68+
[
69+
glm.Tool(
70+
function_declarations=[
71+
glm.FunctionDeclaration(
72+
name="datetime",
73+
description="Returns the current UTC date and time.",
74+
)
75+
]
76+
)
77+
],
78+
],
79+
[
80+
"IterableTool-ToolDict",
81+
[
82+
dict(
83+
function_declarations=[
84+
dict(
85+
name="datetime",
86+
description="Returns the current UTC date and time.",
87+
)
88+
]
89+
)
90+
],
91+
],
92+
[
93+
"IterableTool-IterableFD",
94+
[
95+
[
96+
glm.FunctionDeclaration(
97+
name="datetime",
98+
description="Returns the current UTC date and time.",
99+
)
100+
]
101+
],
102+
],
103+
[
104+
"IterableTool-FD",
105+
[
106+
glm.FunctionDeclaration(
107+
name="datetime",
108+
description="Returns the current UTC date and time.",
109+
)
110+
],
111+
],
112+
[
113+
"Tool",
114+
responder.Tool(
115+
function_declarations=[
116+
glm.FunctionDeclaration(
117+
name="datetime", description="Returns the current UTC date and time."
118+
)
119+
]
120+
),
121+
],
122+
[
123+
"glm.Tool",
124+
glm.Tool(
125+
function_declarations=[
126+
glm.FunctionDeclaration(
127+
name="datetime", description="Returns the current UTC date and time."
128+
)
129+
]
130+
),
131+
],
132+
[
133+
"ToolDict",
134+
dict(
135+
function_declarations=[
136+
dict(name="datetime", description="Returns the current UTC date and time.")
137+
]
138+
),
139+
],
140+
[
141+
"IterableFD-FD",
142+
[
143+
responder.FunctionDeclaration(
144+
name="datetime", description="Returns the current UTC date and time."
145+
)
146+
],
147+
],
148+
[
149+
"IterableFD-CFD",
150+
[
151+
responder.CallableFunctionDeclaration(
152+
name="datetime",
153+
description="Returns the current UTC date and time.",
154+
function=datetime,
155+
)
156+
],
157+
],
158+
[
159+
"IterableFD-dict",
160+
[dict(name="datetime", description="Returns the current UTC date and time.")],
161+
],
162+
["IterableFD-Callable", [datetime]],
163+
[
164+
"FD",
165+
responder.FunctionDeclaration(
166+
name="datetime", description="Returns the current UTC date and time."
167+
),
168+
],
169+
[
170+
"CFD",
171+
responder.CallableFunctionDeclaration(
172+
name="datetime",
173+
description="Returns the current UTC date and time.",
174+
function=datetime,
175+
),
176+
],
177+
[
178+
"glm.FD",
179+
glm.FunctionDeclaration(
180+
name="datetime", description="Returns the current UTC date and time."
181+
),
182+
],
183+
["dict", dict(name="datetime", description="Returns the current UTC date and time.")],
184+
["Callable", datetime],
185+
)
186+
def test_to_tools(self, tools):
187+
function_library = responder.to_function_library(tools)
188+
if function_library is None:
189+
raise ValueError("This shouldn't happen")
190+
tools = function_library.to_proto()
191+
192+
tools = type(tools[0]).to_dict(tools[0])
193+
tools["function_declarations"][0].pop("parameters", None)
194+
195+
expected = dict(
196+
function_declarations=[
197+
dict(name="datetime", description="Returns the current UTC date and time.")
198+
]
199+
)
200+
201+
self.assertEqual(tools, expected)
202+
203+
def test_two_fun_is_one_tool(self):
204+
def a():
205+
pass
206+
207+
def b():
208+
pass
209+
210+
function_library = responder.to_function_library([a, b])
211+
if function_library is None:
212+
raise ValueError("This shouldn't happen")
213+
tools = function_library.to_proto()
214+
215+
self.assertLen(tools, 1)
216+
self.assertLen(tools[0].function_declarations, 2)
217+
218+
@parameterized.named_parameters(
219+
["int", int, glm.Schema(type=glm.Type.INTEGER)],
220+
["float", float, glm.Schema(type=glm.Type.NUMBER)],
221+
["str", str, glm.Schema(type=glm.Type.STRING)],
222+
[
223+
"list",
224+
list[str],
225+
glm.Schema(
226+
type=glm.Type.ARRAY,
227+
items=glm.Schema(type=glm.Type.STRING),
228+
),
229+
],
230+
[
231+
"list-list-int",
232+
list[list[int]],
233+
glm.Schema(
234+
type=glm.Type.ARRAY,
235+
items=glm.Schema(
236+
glm.Schema(
237+
type=glm.Type.ARRAY,
238+
items=glm.Schema(type=glm.Type.INTEGER),
239+
),
240+
),
241+
),
242+
],
243+
["dict", dict, glm.Schema(type=glm.Type.OBJECT)],
244+
["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)],
245+
)
246+
def test_auto_schema(self, annotation, expected):
247+
def fun(a: annotation):
248+
pass
249+
250+
cfd = responder.FunctionDeclaration.from_function(fun)
251+
got = cfd.parameters.properties["a"]
252+
self.assertEqual(got, expected)

0 commit comments

Comments
 (0)