Skip to content

Commit 630318b

Browse files
Rename json_mode.py and create function calling sample (#406)
* Rename json_mode.py and create function calling sample
* Move functions inside test case
* Fix lightbot (Change-Id: If4201ef6e0d0282aec685ed36b935b376f907856)
* Format (Change-Id: Ib052095b489a28ca5a2fc7338c8fefd5ac0adbc5)
* Function calling now passing tests
* Add type: ignore
* Format (Change-Id: I1fea8ebb0e7c7874fdac6743a35e9be37a67301c)
---------
Co-authored-by: Mark Daoust <[email protected]>
1 parent 81aaf35 commit 630318b

File tree

6 files changed

+81
-14
lines changed

6 files changed

+81
-14
lines changed

google/generativeai/types/content_types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,9 @@ def _schema_for_function(
369369
)
370370
)
371371
]
372-
schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters)
372+
schema = dict(name=f.__name__, description=f.__doc__)
373+
if parameters["properties"]:
374+
schema["parameters"] = parameters
373375

374376
return schema
375377

File renamed without changes.

samples/count_tokens.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
media = pathlib.Path(__file__).parents[1] / "third_party"
2121

2222

23-
24-
25-
2623
class UnitTests(absltest.TestCase):
2724
def test_tokens_text_only(self):
2825
# [START tokens_text_only]
@@ -84,8 +81,10 @@ def test_tokens_cached_content(self):
8481
def test_tokens_system_instruction(self):
8582
# [START tokens_system_instruction]
8683
document = genai.upload_file(path=media / "a11.txt")
87-
model = genai.GenerativeModel("models/gemini-1.5-flash-001",
88-
system_instruction="You are an expert analyzing transcripts. Give a summary of this document.")
84+
model = genai.GenerativeModel(
85+
"models/gemini-1.5-flash-001",
86+
system_instruction="You are an expert analyzing transcripts. Give a summary of this document.",
87+
)
8988
print(model.count_tokens(document))
9089
# [END tokens_system_instruction]
9190

@@ -95,25 +94,27 @@ def add(a: float, b: float):
9594
"""returns a + b."""
9695
return a + b
9796

98-
9997
def subtract(a: float, b: float):
10098
"""returns a - b."""
10199
return a - b
102100

103-
104101
def multiply(a: float, b: float):
105102
"""returns a * b."""
106103
return a * b
107104

108-
109105
def divide(a: float, b: float):
110106
"""returns a / b."""
111107
return a / b
112-
113-
model = genai.GenerativeModel("models/gemini-1.5-flash-001",
114-
tools=[add, subtract, multiply, divide])
115-
116-
print(model.count_tokens("I have 57 cats, each owns 44 mittens, how many mittens is that in total?"))
108+
109+
model = genai.GenerativeModel(
110+
"models/gemini-1.5-flash-001", tools=[add, subtract, multiply, divide]
111+
)
112+
113+
print(
114+
model.count_tokens(
115+
"I have 57 cats, each owns 44 mittens, how many mittens is that in total?"
116+
)
117+
)
117118
# [END tokens_tools]
118119

119120

samples/function_calling.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2023 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from absl.testing import absltest
16+
17+
import google.generativeai as genai
18+
19+
20+
class UnitTests(absltest.TestCase):
21+
def test_function_calling(self):
22+
# [START function_calling]
23+
def add(a: float, b: float):
24+
"""returns a + b."""
25+
return a + b
26+
27+
def subtract(a: float, b: float):
28+
"""returns a - b."""
29+
return a - b
30+
31+
def multiply(a: float, b: float):
32+
"""returns a * b."""
33+
return a * b
34+
35+
def divide(a: float, b: float):
36+
"""returns a / b."""
37+
return a / b
38+
39+
model = genai.GenerativeModel(
40+
model_name="gemini-1.5-flash", tools=[add, subtract, multiply, divide]
41+
)
42+
chat = model.start_chat(enable_automatic_function_calling=True)
43+
response = chat.send_message(
44+
"I have 57 cats, each owns 44 mittens, how many mittens is that in total?"
45+
)
46+
print(response.text)
47+
# [END function_calling]
48+
49+
50+
if __name__ == "__main__":
51+
absltest.main()

samples/text_generation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def test_text_gen_text_only_prompt_streaming(self):
4141
def test_text_gen_multimodal_one_image_prompt(self):
4242
# [START text_gen_multimodal_one_image_prompt]
4343
import PIL
44+
4445
model = genai.GenerativeModel("gemini-1.5-flash")
4546
organ = PIL.Image.open(media / "organ.jpg")
4647
response = model.generate_content(["Tell me about this instrument", organ])
@@ -50,6 +51,7 @@ def test_text_gen_multimodal_one_image_prompt(self):
5051
def test_text_gen_multimodal_one_image_prompt_streaming(self):
5152
# [START text_gen_multimodal_one_image_prompt_streaming]
5253
import PIL
54+
5355
model = genai.GenerativeModel("gemini-1.5-flash")
5456
organ = PIL.Image.open(media / "organ.jpg")
5557
response = model.generate_content(["Tell me about this instrument", organ], stream=True)
@@ -61,6 +63,7 @@ def test_text_gen_multimodal_one_image_prompt_streaming(self):
6163
def test_text_gen_multimodal_multi_image_prompt(self):
6264
# [START text_gen_multimodal_multi_image_prompt]
6365
import PIL
66+
6467
model = genai.GenerativeModel("gemini-1.5-flash")
6568
organ = PIL.Image.open(media / "organ.jpg")
6669
cajun_instrument = PIL.Image.open(media / "Cajun_instruments.jpg")
@@ -73,6 +76,7 @@ def test_text_gen_multimodal_multi_image_prompt(self):
7376
def test_text_gen_multimodal_multi_image_prompt_streaming(self):
7477
# [START text_gen_multimodal_multi_image_prompt_streaming]
7578
import PIL
79+
7680
model = genai.GenerativeModel("gemini-1.5-flash")
7781
organ = PIL.Image.open(media / "organ.jpg")
7882
cajun_instrument = PIL.Image.open(media / "Cajun_instruments.jpg")

tests/test_content.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,15 @@ def test_to_tools(self, tools):
378378

379379
self.assertEqual(tools, expected)
380380

381+
def test_empty_function(self):
382+
def no_args():
383+
print("hello")
384+
385+
fd = content_types.to_function_library(no_args).to_proto()[0] # type: ignore
386+
fd = type(fd).to_dict(fd, including_default_value_fields=False)
387+
# parameters are not set.
388+
self.assertEqual({"function_declarations": [{"name": "no_args"}]}, fd)
389+
381390
@parameterized.named_parameters(
382391
["string", "code_execution"],
383392
["proto_object", protos.CodeExecution()],

0 commit comments

Comments (0)