Skip to content

Commit 8abd8f7

Browse files
authored
Allow more flexible safety_settings. (#44)
* Allow more flexible safety_settings. * fix 3.9
1 parent 7cd3d52 commit 8abd8f7

File tree

3 files changed

+142
-15
lines changed

3 files changed

+142
-15
lines changed

google/generativeai/text.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,12 @@ def _make_generate_text_request(
4646
max_output_tokens: int | None = None,
4747
top_p: int | None = None,
4848
top_k: int | None = None,
49-
safety_settings: List[safety_types.SafetySettingDict] | None = None,
49+
safety_settings: safety_types.SafetySettingOptions | None = None,
5050
stop_sequences: str | Iterable[str] | None = None,
5151
) -> glm.GenerateTextRequest:
5252
model = model_types.make_model_name(model)
5353
prompt = _make_text_prompt(prompt=prompt)
54+
safety_settings = safety_types.normalize_safety_settings(safety_settings)
5455
if isinstance(stop_sequences, str):
5556
stop_sequences = [stop_sequences]
5657
if stop_sequences:
@@ -78,7 +79,7 @@ def generate_text(
7879
max_output_tokens: int | None = None,
7980
top_p: float | None = None,
8081
top_k: float | None = None,
81-
safety_settings: Iterable[safety_types.SafetySettingDict] | None = None,
82+
safety_settings: safety_types.SafetySettingOptions | None = None,
8283
stop_sequences: str | Iterable[str] | None = None,
8384
client: glm.TextServiceClient | None = None,
8485
) -> text_types.Completion:

google/generativeai/types/safety_types.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
from __future__ import annotations
1616

1717
import enum
18+
from collections.abc import Mapping
19+
1820
from google.ai import generativelanguage as glm
1921
from google.generativeai import docstring_utils
2022
import typing
21-
from typing import Iterable, List, TypedDict
23+
from typing import Dict, Iterable, List, TypedDict, Union
2224

2325
__all__ = [
2426
"HarmCategory",
@@ -37,6 +39,69 @@
3739
HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold
3840
BlockedReason = glm.ContentFilter.BlockedReason
3941

42+
HarmCategoryOptions = Union[str, int, HarmCategory]
43+
44+
_HARM_CATEGORIES: Dict[HarmCategoryOptions, HarmCategory] = {
45+
HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmCategory.HARM_CATEGORY_UNSPECIFIED,
46+
0: HarmCategory.HARM_CATEGORY_UNSPECIFIED,
47+
HarmCategory.HARM_CATEGORY_DEROGATORY: HarmCategory.HARM_CATEGORY_DEROGATORY,
48+
1: HarmCategory.HARM_CATEGORY_DEROGATORY,
49+
"derogatory": HarmCategory.HARM_CATEGORY_DEROGATORY,
50+
HarmCategory.HARM_CATEGORY_TOXICITY: HarmCategory.HARM_CATEGORY_TOXICITY,
51+
2: HarmCategory.HARM_CATEGORY_TOXICITY,
52+
"toxicity": HarmCategory.HARM_CATEGORY_TOXICITY,
53+
"toxic": HarmCategory.HARM_CATEGORY_TOXICITY,
54+
HarmCategory.HARM_CATEGORY_VIOLENCE: HarmCategory.HARM_CATEGORY_VIOLENCE,
55+
3: HarmCategory.HARM_CATEGORY_VIOLENCE,
56+
"violence": HarmCategory.HARM_CATEGORY_VIOLENCE,
57+
"violent": HarmCategory.HARM_CATEGORY_VIOLENCE,
58+
HarmCategory.HARM_CATEGORY_SEXUAL: HarmCategory.HARM_CATEGORY_SEXUAL,
59+
4: HarmCategory.HARM_CATEGORY_SEXUAL,
60+
"sexual": HarmCategory.HARM_CATEGORY_SEXUAL,
61+
"sex": HarmCategory.HARM_CATEGORY_SEXUAL,
62+
HarmCategory.HARM_CATEGORY_MEDICAL: HarmCategory.HARM_CATEGORY_MEDICAL,
63+
5: HarmCategory.HARM_CATEGORY_MEDICAL,
64+
"medical": HarmCategory.HARM_CATEGORY_MEDICAL,
65+
"med": HarmCategory.HARM_CATEGORY_MEDICAL,
66+
HarmCategory.HARM_CATEGORY_DANGEROUS: HarmCategory.HARM_CATEGORY_DANGEROUS,
67+
6: HarmCategory.HARM_CATEGORY_DANGEROUS,
68+
"danger": HarmCategory.HARM_CATEGORY_DANGEROUS,
69+
"dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS,
70+
}
71+
72+
73+
def to_harm_category(x: HarmCategoryOptions) -> HarmCategory:
74+
if isinstance(x, str):
75+
x = x.lower()
76+
return _HARM_CATEGORIES[x]
77+
78+
79+
HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold]
80+
81+
_BLOCK_THRESHOLDS: Dict[HarmBlockThresholdOptions, HarmBlockThreshold] = {
82+
HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
83+
0: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
84+
HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
85+
1: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
86+
"low": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
87+
HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
88+
2: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
89+
"medium": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
90+
"med": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
91+
HarmBlockThreshold.BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
92+
3: HarmBlockThreshold.BLOCK_ONLY_HIGH,
93+
"high": HarmBlockThreshold.BLOCK_ONLY_HIGH,
94+
HarmBlockThreshold.BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE,
95+
4: HarmBlockThreshold.BLOCK_NONE,
96+
"block_none": HarmBlockThreshold.BLOCK_NONE,
97+
}
98+
99+
100+
def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold:
101+
if isinstance(x, str):
102+
x = x.lower()
103+
return _BLOCK_THRESHOLDS[x]
104+
40105

41106
class ContentFilterDict(TypedDict):
42107
reason: BlockedReason
@@ -83,6 +148,35 @@ class SafetySettingDict(TypedDict):
83148
__doc__ = docstring_utils.strip_oneof(glm.SafetySetting.__doc__)
84149

85150

151+
class LooseSafetySettingDict(TypedDict):
152+
category: HarmCategoryOptions
153+
threshold: HarmBlockThresholdOptions
154+
155+
156+
EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
157+
158+
SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None]
159+
160+
161+
def normalize_safety_settings(
162+
settings: SafetySettingOptions,
163+
) -> list[SafetySettingDict] | None:
164+
if settings is None:
165+
return None
166+
if isinstance(settings, Mapping):
167+
return [
168+
{"category": to_harm_category(key), "threshold": to_block_threshold(value)}
169+
for key, value in settings.items()
170+
]
171+
return [
172+
{
173+
"category": to_harm_category(d["category"]),
174+
"threshold": to_block_threshold(d["threshold"]),
175+
}
176+
for d in settings
177+
]
178+
179+
86180
def convert_setting_to_enum(setting: dict) -> SafetySettingDict:
87181
return {
88182
"category": HarmCategory(setting["category"]),

tests/test_text.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,51 @@ def test_stop_string(self):
154154
# Just make sure it made it into the request object.
155155
self.assertEqual(self.observed_request.stop_sequences, ["stop"])
156156

157-
def test_safety_settings(self):
158-
result = text_service.generate_text(
159-
prompt="Say something wicked.",
160-
safety_settings=[
161-
{
162-
"category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
163-
"threshold": safety_types.HarmBlockThreshold.BLOCK_NONE,
164-
},
165-
{
166-
"category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
167-
"threshold": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
157+
@parameterized.named_parameters(
158+
[
159+
dict(
160+
testcase_name="basic",
161+
safety_settings=[
162+
{
163+
"category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL,
164+
"threshold": safety_types.HarmBlockThreshold.BLOCK_NONE,
165+
},
166+
{
167+
"category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE,
168+
"threshold": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
169+
},
170+
],
171+
),
172+
dict(
173+
testcase_name="strings",
174+
safety_settings=[
175+
{
176+
"category": "medical",
177+
"threshold": "block_none",
178+
},
179+
{
180+
"category": "violent",
181+
"threshold": "low",
182+
},
183+
],
184+
),
185+
dict(
186+
testcase_name="flat",
187+
safety_settings={"medical": "block_none", "sex": "low"},
188+
),
189+
dict(
190+
testcase_name="mixed",
191+
safety_settings={
192+
"medical": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
193+
safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE: 1,
168194
},
169-
],
195+
),
196+
]
197+
)
198+
def test_safety_settings(self, safety_settings):
199+
# This test really just checks that the safety_settings get converted to a proto.
200+
result = text_service.generate_text(
201+
prompt="Say something wicked.", safety_settings=safety_settings
170202
)
171203

172204
self.assertEqual(

0 commit comments

Comments
 (0)