Skip to content

Commit 63a140b

Browse files
Code execution (#396)
* Modify _make_tool to accept code_execution proto object * Make adding code execution compatible with current tests * Update test_content.py to test code_execution * Test cases running and passing for code execution updates * Updated _join_contents to handle executable_code and code_execution_results * Updated .text field to include executable code and code execution results * Update versions. Change-Id: I4e2c9081a9de9105bfc1d9d83769d98faee4d18e * format Change-Id: I38de8c9fe434785608ad5e04a2234bd584e2de03 --------- Co-authored-by: Mark Daoust <[email protected]>
1 parent 3e281e7 commit 63a140b

File tree

6 files changed

+183
-32
lines changed

6 files changed

+183
-32
lines changed

google/generativeai/types/content_types.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -623,24 +623,40 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
623623
class Tool:
624624
"""A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects."""
625625

626-
def __init__(self, function_declarations: Iterable[FunctionDeclarationType]):
626+
def __init__(
627+
self,
628+
function_declarations: Iterable[FunctionDeclarationType] | None = None,
629+
code_execution: protos.CodeExecution | None = None,
630+
):
627631
# The main path doesn't use this but is seems useful.
628-
self._function_declarations = [_make_function_declaration(f) for f in function_declarations]
629-
self._index = {}
630-
for fd in self._function_declarations:
631-
name = fd.name
632-
if name in self._index:
633-
raise ValueError("")
634-
self._index[fd.name] = fd
632+
if function_declarations:
633+
self._function_declarations = [
634+
_make_function_declaration(f) for f in function_declarations
635+
]
636+
self._index = {}
637+
for fd in self._function_declarations:
638+
name = fd.name
639+
if name in self._index:
640+
raise ValueError("")
641+
self._index[fd.name] = fd
642+
else:
643+
# Consistent fields
644+
self._function_declarations = []
645+
self._index = {}
635646

636647
self._proto = protos.Tool(
637-
function_declarations=[_encode_fd(fd) for fd in self._function_declarations]
648+
function_declarations=[_encode_fd(fd) for fd in self._function_declarations],
649+
code_execution=code_execution,
638650
)
639651

640652
@property
641653
def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]:
642654
return self._function_declarations
643655

656+
@property
657+
def code_execution(self) -> protos.CodeExecution:
658+
return self._proto.code_execution
659+
644660
def __getitem__(
645661
self, name: str | protos.FunctionCall
646662
) -> FunctionDeclaration | protos.FunctionDeclaration:
@@ -673,13 +689,24 @@ def _make_tool(tool: ToolType) -> Tool:
673689
if isinstance(tool, Tool):
674690
return tool
675691
elif isinstance(tool, protos.Tool):
676-
return Tool(function_declarations=tool.function_declarations)
692+
if "code_execution" in tool:
693+
code_execution = tool.code_execution
694+
else:
695+
code_execution = None
696+
return Tool(function_declarations=tool.function_declarations, code_execution=code_execution)
677697
elif isinstance(tool, dict):
678-
if "function_declarations" in tool:
698+
if "function_declarations" in tool or "code_execution" in tool:
679699
return Tool(**tool)
680700
else:
681701
fd = tool
682702
return Tool(function_declarations=[protos.FunctionDeclaration(**fd)])
703+
elif isinstance(tool, str):
704+
if tool.lower() == "code_execution":
705+
return Tool(code_execution=protos.CodeExecution())
706+
else:
707+
raise ValueError("The only string that can be passed as a tool is 'code_execution'.")
708+
elif isinstance(tool, protos.CodeExecution):
709+
return Tool(code_execution=tool)
683710
elif isinstance(tool, Iterable):
684711
return Tool(function_declarations=tool)
685712
else:
@@ -734,7 +761,12 @@ def to_proto(self):
734761

735762

736763
def _make_tools(tools: ToolsType) -> list[Tool]:
737-
if isinstance(tools, Iterable) and not isinstance(tools, Mapping):
764+
if isinstance(tools, str):
765+
if tools.lower() == "code_execution":
766+
return [_make_tool(tools)]
767+
else:
768+
raise ValueError("The only string that can be passed as a tool is 'code_execution'.")
769+
elif isinstance(tools, Iterable) and not isinstance(tools, Mapping):
738770
tools = [_make_tool(t) for t in tools]
739771
if len(tools) > 1 and all(len(t.function_declarations) == 1 for t in tools):
740772
# flatten into a single tool.

google/generativeai/types/generation_types.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -261,26 +261,50 @@ def _join_contents(contents: Iterable[protos.Content]):
261261
for content in contents:
262262
parts.extend(content.parts)
263263

264-
merged_parts = [parts.pop(0)]
265-
for part in parts:
266-
if not merged_parts[-1].text:
267-
merged_parts.append(part)
264+
merged_parts = []
265+
last = parts[0]
266+
for part in parts[1:]:
267+
if "text" in last and "text" in part:
268+
last = protos.Part(text=last.text + part.text)
268269
continue
269270

270-
if not part.text:
271-
merged_parts.append(part)
271+
# Can we merge the new thing into last?
272+
# If not, put last in list of parts, and new thing becomes last
273+
if "executable_code" in last and "executable_code" in part:
274+
last = protos.Part(
275+
executable_code=_join_executable_code(last.executable_code, part.executable_code)
276+
)
272277
continue
273278

274-
merged_part = protos.Part(merged_parts[-1])
275-
merged_part.text += part.text
276-
merged_parts[-1] = merged_part
279+
if "code_execution_result" in last and "code_execution_result" in part:
280+
last = protos.Part(
281+
code_execution_result=_join_code_execution_result(
282+
last.code_execution_result, part.code_execution_result
283+
)
284+
)
285+
continue
286+
287+
merged_parts.append(last)
288+
last = part
289+
290+
merged_parts.append(last)
277291

278292
return protos.Content(
279293
role=role,
280294
parts=merged_parts,
281295
)
282296

283297

298+
def _join_executable_code(code_1, code_2):
299+
return protos.ExecutableCode(language=code_1.language, code=code_1.code + code_2.code)
300+
301+
302+
def _join_code_execution_result(result_1, result_2):
303+
return protos.CodeExecutionResult(
304+
outcome=result_2.outcome, output=result_1.output + result_2.output
305+
)
306+
307+
284308
def _join_candidates(candidates: Iterable[protos.Candidate]):
285309
candidates = tuple(candidates)
286310

@@ -413,13 +437,35 @@ def text(self):
413437
"Invalid operation: The `response.text` quick accessor requires the response to contain a valid `Part`, "
414438
"but none were returned. Please check the `candidate.safety_ratings` to determine if the response was blocked."
415439
)
416-
if len(parts) != 1 or "text" not in parts[0]:
417-
raise ValueError(
418-
"Invalid operation: The `response.text` quick accessor requires a simple (single-`Part`) text response. "
419-
"This response is not simple text. Please use the `result.parts` accessor or the full "
420-
"`result.candidates[index].content.parts` lookup instead."
421-
)
422-
return parts[0].text
440+
441+
texts = []
442+
for part in parts:
443+
if "text" in part:
444+
texts.append(part.text)
445+
continue
446+
if "executable_code" in part:
447+
language = part.executable_code.language.name.lower()
448+
if language == "language_unspecified":
449+
language = ""
450+
else:
451+
language = f" {language}"
452+
texts.extend([f"```{language}", part.executable_code.code, "```"])
453+
continue
454+
if "code_execution_result" in part:
455+
outcome_result = part.code_execution_result.outcome.name.lower().replace(
456+
"outcome_", ""
457+
)
458+
if outcome_result == "ok" or outcome_result == "unspecified":
459+
outcome_result = ""
460+
else:
461+
outcome_result = f" {outcome_result}"
462+
texts.extend([f"```{outcome_result}", part.code_execution_result.output, "```"])
463+
continue
464+
465+
part_type = protos.Part.pb(part).whichOneof("data")
466+
raise ValueError(f"Could not convert `part.{part_type}` to text.")
467+
468+
return "\n".join(texts)
423469

424470
@property
425471
def prompt_feedback(self):

google/generativeai/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17-
__version__ = "0.7.0"
17+
__version__ = "0.7.1"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_version():
4242
release_status = "Development Status :: 5 - Production/Stable"
4343

4444
dependencies = [
45-
"google-ai-generativelanguage==0.6.5",
45+
"google-ai-generativelanguage==0.6.6",
4646
"google-api-core",
4747
"google-api-python-client",
4848
"google-auth>=2.15.0", # 2.15 adds API key auth support

tests/test_content.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import dataclasses
1616
import pathlib
1717
import typing_extensions
18-
from typing import Any, Union
18+
from typing import Any, Union, Iterable
1919

2020
from absl.testing import absltest
2121
from absl.testing import parameterized
@@ -367,7 +367,7 @@ def test_to_tools(self, tools):
367367
raise ValueError("This shouldn't happen")
368368
tools = function_library.to_proto()
369369

370-
tools = type(tools[0]).to_dict(tools[0])
370+
tools = type(tools[0]).to_dict(tools[0], including_default_value_fields=False)
371371
tools["function_declarations"][0].pop("parameters", None)
372372

373373
expected = dict(
@@ -378,6 +378,24 @@ def test_to_tools(self, tools):
378378

379379
self.assertEqual(tools, expected)
380380

381+
@parameterized.named_parameters(
382+
["string", "code_execution"],
383+
["proto_object", protos.CodeExecution()],
384+
["proto_passed_in", protos.Tool(code_execution=protos.CodeExecution())],
385+
["empty_dictionary", {"code_execution": {}}],
386+
["string_list", ["code_execution"]],
387+
["proto_object_list", [protos.CodeExecution()]],
388+
["proto_passed_in_list", [protos.Tool(code_execution=protos.CodeExecution())]],
389+
["empty_dictionary_list", [{"code_execution": {}}]],
390+
)
391+
def test_code_execution(self, tools):
392+
if isinstance(tools, Iterable):
393+
t = content_types._make_tools(tools)
394+
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)
395+
else:
396+
t = content_types._make_tool(tools) # Pass code execution into tools
397+
self.assertIsInstance(t.code_execution, protos.CodeExecution)
398+
381399
def test_two_fun_is_one_tool(self):
382400
def a():
383401
pass

tests/test_generation.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,61 @@ def test_join_contents(self):
124124

125125
self.assertEqual(expected, type(result).to_dict(result))
126126

127+
def test_join_parts(self):
128+
contents = [
129+
protos.Content(role="assistant", parts=[protos.Part(text="A")]),
130+
protos.Content(role="assistant", parts=[protos.Part(text="B")]),
131+
protos.Content(role="assistant", parts=[protos.Part(executable_code={"code": "C"})]),
132+
protos.Content(role="assistant", parts=[protos.Part(executable_code={"code": "D"})]),
133+
protos.Content(
134+
role="assistant", parts=[protos.Part(code_execution_result={"output": "E"})]
135+
),
136+
protos.Content(
137+
role="assistant", parts=[protos.Part(code_execution_result={"output": "F"})]
138+
),
139+
protos.Content(role="assistant", parts=[protos.Part(text="G")]),
140+
protos.Content(role="assistant", parts=[protos.Part(text="H")]),
141+
]
142+
g = generation_types._join_contents(contents=contents)
143+
expected = protos.Content(
144+
role="assistant",
145+
parts=[
146+
protos.Part(text="AB"),
147+
protos.Part(executable_code={"code": "CD"}),
148+
protos.Part(code_execution_result={"output": "EF"}),
149+
protos.Part(text="GH"),
150+
],
151+
)
152+
self.assertEqual(expected, g)
153+
154+
def test_code_execution_text(self):
155+
content = protos.Content(
156+
role="assistant",
157+
parts=[
158+
protos.Part(text="AB"),
159+
protos.Part(executable_code={"language": "PYTHON", "code": "CD"}),
160+
protos.Part(code_execution_result={"outcome": "OUTCOME_OK", "output": "EF"}),
161+
protos.Part(text="GH"),
162+
],
163+
)
164+
response = generation_types.GenerateContentResponse(
165+
done=True,
166+
iterator=None,
167+
result=protos.GenerateContentResponse({"candidates": [{"content": content}]}),
168+
)
169+
expected = textwrap.dedent(
170+
"""\
171+
AB
172+
``` python
173+
CD
174+
```
175+
```
176+
EF
177+
```
178+
GH"""
179+
)
180+
self.assertEqual(expected, response.text)
181+
127182
def test_many_join_contents(self):
128183
import string
129184

0 commit comments

Comments
 (0)