Skip to content

Commit e56fe19

Browse files
committed
test generate_text safety settings, filters and feedback.
1 parent 7a7afb2 commit e56fe19

File tree

11 files changed

+171
-55
lines changed

11 files changed

+171
-55
lines changed

google/generativeai/discuss.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,9 @@ def __init__(self, **kwargs):
392392
@set_doc(discuss_types.ChatResponse.last.__doc__)
393393
def last(self) -> Optional[str]:
394394
if self.messages[-1]:
395-
return self.messages[-1]["content"]
395+
return self.messages[-1]["content"]
396396
else:
397-
return None
397+
return None
398398

399399
@last.setter
400400
def last(self, message: discuss_types.MessageOptions):
@@ -410,9 +410,11 @@ def reply(
410410
f"reply can't be called on an async client, use reply_async instead."
411411
)
412412
if self.last is None:
413-
raise ValueError('The last response from the model did not return any candidates.\n'
414-
'Check the `.filters` attribute to see why the responses were filtered:\n'
415-
f'{self.filters}')
413+
raise ValueError(
414+
"The last response from the model did not return any candidates.\n"
415+
"Check the `.filters` attribute to see why the responses were filtered:\n"
416+
f"{self.filters}"
417+
)
416418

417419
request = self.to_dict()
418420
request.pop("candidates")
@@ -438,9 +440,6 @@ async def reply_async(
438440
request = _make_generate_message_request(**request)
439441
return await _generate_response_async(request=request, client=self._client)
440442

441-
def _convert_filters_to_enums(filters):
442-
for f in filters:
443-
f['reason'] = safety_types.BlockedReason(f['reason'])
444443

445444
def _build_chat_response(
446445
request: glm.GenerateMessageRequest,
@@ -456,7 +455,7 @@ def _build_chat_response(
456455
response = type(response).to_dict(response)
457456
response.pop("messages")
458457

459-
_convert_filters_to_enums(response['filters'])
458+
safety_types.convert_filters_to_enums(response["filters"])
460459

461460
if response["candidates"]:
462461
last = response["candidates"][0]

google/generativeai/docstring_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
1617
def strip_oneof(docstring):
1718
lines = docstring.splitlines()
1819
lines = [line for line in lines if ".. _oneof:" not in line]
1920
lines = [line for line in lines if "This field is a member of `oneof`_" not in line]
20-
return "\n".join(lines)
21+
return "\n".join(lines)

google/generativeai/text.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def _make_generate_text_request(
4545
max_output_tokens: Optional[int] = None,
4646
top_p: Optional[int] = None,
4747
top_k: Optional[int] = None,
48-
safety_settings: Optional[List[safety_types.SafetySetting]] = None,
48+
safety_settings: Optional[List[safety_types.SafetySettingDict]] = None,
4949
stop_sequences: Union[str, Iterable[str]] = None,
5050
) -> glm.GenerateTextRequest:
5151
model = model_types.make_model_name(model)
@@ -77,7 +77,7 @@ def generate_text(
7777
max_output_tokens: Optional[int] = None,
7878
top_p: Optional[float] = None,
7979
top_k: Optional[float] = None,
80-
safety_settings: Optional[Iterable[safety.SafetySetting]] = None,
80+
safety_settings: Optional[Iterable[safety.SafetySettingDict]] = None,
8181
stop_sequences: Union[str, Iterable[str]] = None,
8282
client: Optional[glm.TextServiceClient] = None,
8383
) -> text_types.Completion:
@@ -159,6 +159,9 @@ def _generate_response(
159159
response = client.generate_text(request)
160160
response = type(response).to_dict(response)
161161

162+
safety_types.convert_filters_to_enums(response["filters"])
163+
safety_types.convert_safety_feedback_to_enums(response["safety_feedback"])
164+
162165
return Completion(_client=client, **response)
163166

164167

google/generativeai/types/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@
2424
del model_types
2525
del text_types
2626
del citation_types
27-
del safety_types
27+
del safety_types

google/generativeai/types/discuss_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ class ChatResponse(abc.ABC):
162162
candidates: List[MessageDict]
163163
top_p: Optional[float] = None
164164
top_k: Optional[float] = None
165-
filters: List[safety_types.ContentFilter]
165+
filters: List[safety_types.ContentFilterDict]
166166

167167
@property
168168
@abc.abstractmethod

google/generativeai/types/safety_types.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@
2323
"HarmProbability",
2424
"HarmBlockThreshold",
2525
"BlockedReason",
26-
"ContentFilter",
26+
"ContentFilterDict",
2727
"SafetyRatingDict",
28-
"SafetySetting",
28+
"SafetySettingDict",
2929
"SafetyFeedbackDict",
3030
]
3131

@@ -36,29 +36,50 @@
3636
BlockedReason = glm.ContentFilter.BlockedReason
3737

3838

39-
class ContentFilter(TypedDict):
39+
class ContentFilterDict(TypedDict):
4040
reason: BlockedReason
4141
message: str
4242

4343
__doc__ = docstring_utils.strip_oneof(glm.ContentFilter.__doc__)
4444

4545

46+
def convert_filters_to_enums(filters):
47+
for f in filters:
48+
f["reason"] = BlockedReason(f["reason"])
49+
50+
4651
class SafetyRatingDict(TypedDict):
4752
category: HarmCategory
4853
probability: HarmProbability
4954

5055
__doc__ = docstring_utils.strip_oneof(glm.SafetyRating.__doc__)
5156

5257

53-
class SafetySetting(TypedDict):
58+
def convert_rating_to_enum(setting):
59+
setting["category"] = HarmCategory(setting["category"])
60+
setting["probability"] = HarmProbability(setting["probability"])
61+
62+
63+
class SafetySettingDict(TypedDict):
5464
category: HarmCategory
5565
threshold: HarmBlockThreshold
5666

5767
__doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__)
5868

5969

70+
def convert_setting_to_enum(setting):
71+
setting["category"] = HarmCategory(setting["category"])
72+
setting["threshold"] = HarmBlockThreshold(setting["threshold"])
73+
74+
6075
class SafetyFeedbackDict(TypedDict):
6176
rating: SafetyRatingDict
62-
setting: SafetySetting
77+
setting: SafetySettingDict
6378

6479
__doc__ = docstring_utils.strip_oneof(glm.SafetyFeedback.__doc__)
80+
81+
82+
def convert_safety_feedback_to_enums(safety_feedback):
83+
for sf in safety_feedback:
84+
convert_rating_to_enum(sf["rating"])
85+
convert_setting_to_enum(sf["setting"])

google/generativeai/types/text_types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class TextCompletion(TypedDict, total=False):
3232

3333
@dataclasses.dataclass(init=False)
3434
class Completion(abc.ABC):
35-
"""The result of the `1 given a prompt from the model.
35+
"""The result returned by `generativeai.generate_text`.
3636
3737
Use `GenerateTextResponse.candidates` to access all the completions generated by the model.
3838
@@ -43,9 +43,10 @@ class Completion(abc.ABC):
4343
Either Unspecified, Safety, or Other. See `types.ContentFilter`.
4444
safety_feedback: Indicates which safety settings blocked content in this result.
4545
"""
46+
4647
candidates: List[TextCompletion]
4748
result: Optional[str]
48-
filters: Optional[list[safety_types.ContentFilter]]
49+
filters: Optional[list[safety_types.ContentFilterDict]]
4950
safety_feedback: Optional[list[safety_types.SafetyFeedbackDict]]
5051

5152
def to_dict(self) -> Dict[str, Any]:

setup.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@
3434
else:
3535
release_status = "Development Status :: 5 - Production/Stable"
3636

37-
dependencies = [
38-
"google-ai-generativelanguage==0.2.0"
39-
]
37+
dependencies = ["google-ai-generativelanguage==0.2.0"]
4038

4139
extras_require = {
4240
"dev": [

tests/test_discuss.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from google.generativeai import discuss
2222
from google.generativeai import client
2323
import google.generativeai as genai
24+
from google.generativeai.types import safety_types
25+
2426
from absl.testing import absltest
2527
from absl.testing import parameterized
2628

@@ -36,12 +38,12 @@ def setUp(self):
3638
self.observed_request = None
3739

3840
self.mock_response = glm.GenerateMessageResponse(
39-
candidates=[
40-
glm.Message(content="a", author="1"),
41-
glm.Message(content="b", author="1"),
42-
glm.Message(content="c", author="1"),
43-
],
44-
)
41+
candidates=[
42+
glm.Message(content="a", author="1"),
43+
glm.Message(content="b", author="1"),
44+
glm.Message(content="c", author="1"),
45+
],
46+
)
4547

4648
def fake_generate_message(
4749
request: glm.GenerateMessageRequest,
@@ -274,32 +276,39 @@ def test_reply(self, kwargs):
274276
response = response.reply("again")
275277

276278
def test_receive_and_reply_with_filters(self):
277-
278279
self.mock_response = mock_response = glm.GenerateMessageResponse(
279280
candidates=[glm.Message(content="a", author="1")],
280281
filters=[
281-
glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.SAFETY, message='unsafe'),
282-
glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.OTHER),]
282+
glm.ContentFilter(
283+
reason=safety_types.BlockedReason.SAFETY, message="unsafe"
284+
),
285+
glm.ContentFilter(reason=safety_types.BlockedReason.OTHER),
286+
],
283287
)
284288
response = discuss.chat(messages="do filters work?")
285289

286290
filters = response.filters
287291
self.assertLen(filters, 2)
288-
self.assertIsInstance(filters[0]['reason'], glm.ContentFilter.BlockedReason)
289-
self.assertEquals(filters[0]['reason'], glm.ContentFilter.BlockedReason.SAFETY)
290-
self.assertEquals(filters[0]['message'], 'unsafe')
292+
self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)
293+
self.assertEqual(filters[0]["reason"], safety_types.BlockedReason.SAFETY)
294+
self.assertEqual(filters[0]["message"], "unsafe")
291295

292296
self.mock_response = glm.GenerateMessageResponse(
293297
candidates=[glm.Message(content="a", author="1")],
294298
filters=[
295-
glm.ContentFilter(reason=glm.ContentFilter.BlockedReason.BLOCKED_REASON_UNSPECIFIED)]
299+
glm.ContentFilter(
300+
reason=safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED
301+
)
302+
],
296303
)
297304

298-
response = response.reply('Does reply work?')
305+
response = response.reply("Does reply work?")
299306
filters = response.filters
300307
self.assertLen(filters, 1)
301-
self.assertIsInstance(filters[0]['reason'], glm.ContentFilter.BlockedReason)
302-
self.assertEquals(filters[0]['reason'], glm.ContentFilter.BlockedReason.BLOCKED_REASON_UNSPECIFIED)
308+
self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason)
309+
self.assertEqual(
310+
filters[0]["reason"], safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED
311+
)
303312

304313

305314
if __name__ == "__main__":

tests/test_discuss_async.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import unittest
1818

1919
if sys.version_info < (3, 11):
20-
import asynctest
21-
from asynctest import mock as async_mock
20+
import asynctest
21+
from asynctest import mock as async_mock
2222

2323
import google.ai.generativelanguage as glm
2424

@@ -29,11 +29,16 @@
2929
bases = (parameterized.TestCase,)
3030

3131
if sys.version_info < (3, 11):
32-
bases = bases + (asynctest.TestCase,)
32+
bases = bases + (asynctest.TestCase,)
33+
34+
unittest.skipIf(
35+
sys.version_info >= (3, 11), "asynctest is not supported on python 3.11+"
36+
)
37+
3338

34-
unittest.skipIf(sys.version_info >= (3,11), "asynctest is not supported on python 3.11+")
3539
class AsyncTests(*bases):
3640
if sys.version_info < (3, 11):
41+
3742
async def test_chat_async(self):
3843
client = async_mock.MagicMock()
3944

0 commit comments

Comments (0)