Skip to content

Commit ba6b439

Browse files
mayureshagashe2105markmcdelsiel23
authored
System instruction (#270)
* fix arg name: system_instructions to system_instruction proto def for glm.GenerateContentRequest lists system_instruction as singular * import TypeDict from typing_extensions * Update glm dependecy to use 0.6.1 to support files, SI, tool_config * Handle function_calling_mode when passed as a dict with allowed_func_names * format * Scope TypedDict test to package directory * De-pluralise 'instructions' * System instructions tests and blacken * format --------- Co-authored-by: Mark McDonald <[email protected]> Co-authored-by: Elsie L <[email protected]>
1 parent 0778d56 commit ba6b439

File tree

5 files changed

+36
-9
lines changed

5 files changed

+36
-9
lines changed

google/generativeai/answer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
import dataclasses
1818
from collections.abc import Iterable
1919
import itertools
20-
from typing import Any, Iterable, Union, Mapping, Optional, TypedDict
20+
from typing import Any, Iterable, Union, Mapping, Optional
21+
from typing_extensions import TypedDict
2122

2223
import google.ai.generativelanguage as glm
2324

google/generativeai/generative_models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
generation_config: generation_types.GenerationConfigType | None = None,
7575
tools: content_types.FunctionLibraryType | None = None,
7676
tool_config: content_types.ToolConfigType | None = None,
77-
system_instructions: content_types.ContentType | None = None,
77+
system_instruction: content_types.ContentType | None = None,
7878
):
7979
if "/" not in model_name:
8080
model_name = "models/" + model_name
@@ -90,10 +90,10 @@ def __init__(
9090
else:
9191
self._tool_config = content_types.to_tool_config(tool_config)
9292

93-
if system_instructions is None:
94-
self._system_instructions = None
93+
if system_instruction is None:
94+
self._system_instruction = None
9595
else:
96-
self._system_instructions = content_types.to_content(system_instructions)
96+
self._system_instruction = content_types.to_content(system_instruction)
9797

9898
self._client = None
9999
self._async_client = None
@@ -155,7 +155,7 @@ def _prepare_request(
155155
safety_settings=merged_ss,
156156
tools=tools_lib,
157157
tool_config=tool_config,
158-
system_instructions=self._system_instructions,
158+
system_instruction=self._system_instruction,
159159
)
160160

161161
def _get_tools_lib(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_version():
4242
release_status = "Development Status :: 5 - Production/Stable"
4343

4444
dependencies = [
45-
"google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz",
45+
"google-ai-generativelanguage==0.6.1",
4646
"google-api-core",
4747
"google-api-python-client",
4848
"google-auth>=2.15.0", # 2.15 adds API key auth support

tests/test_generative_models.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,16 @@
2121
TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes()
2222

2323

24+
def simple_part(text: str) -> glm.Content:
25+
return glm.Content({"parts": [{"text": text}]})
26+
27+
28+
def iter_part(texts: Iterable[str]) -> glm.Content:
29+
return glm.Content({"parts": [{"text": t} for t in texts]})
30+
31+
2432
def simple_response(text: str) -> glm.GenerateContentResponse:
25-
return glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]})
33+
return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]})
2634

2735

2836
class CUJTests(parameterized.TestCase):
@@ -605,6 +613,24 @@ def test_tools(self):
605613
self.assertLen(obr.tools, 1)
606614
self.assertEqual(type(obr.tools[0]).to_dict(obr.tools[0]), tools)
607615

616+
@parameterized.named_parameters(
617+
["bare_str", "talk like a pirate", simple_part("talk like a pirate")],
618+
[
619+
"part_dict",
620+
{"parts": [{"text": "talk like a pirate"}]},
621+
simple_part("talk like a pirate"),
622+
],
623+
["part_list", ["talk like:", "a pirate"], iter_part(["talk like:", "a pirate"])],
624+
)
625+
def test_system_instruction(self, instruction, expected_instr):
626+
self.responses["generate_content"] = [simple_response("echo echo")]
627+
model = generative_models.GenerativeModel("gemini-pro", system_instruction=instruction)
628+
629+
_ = model.generate_content("test")
630+
631+
[req] = self.observed_requests
632+
self.assertEqual(req.system_instruction, expected_instr)
633+
608634
@parameterized.named_parameters(
609635
["basic", "Hello"],
610636
["list", ["Hello"]],

tests/test_typing_extensions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class TypingExtensionsTests(absltest.TestCase):
3333

3434
def test_no_typing_typed_dict(self):
3535
root = pathlib.Path(__file__).parent.parent
36-
for fpath in root.rglob("*.py"):
36+
for fpath in (root / "google").rglob("*.py"):
3737
source = fpath.read_text()
3838
if match := TYPING_RE.search(source):
3939
raise ValueError(

0 commit comments

Comments
 (0)