Commit 386994a

Quick safety filtering: Allow safety_settings="block_none" (#347)
* allow safety_settings='off'
  Change-Id: Ica10b399177301073424a98cb3a8b0736dc216b4
* Fix tests.
  Change-Id: I06cfd07397e984b9fb757b2831b419eefb8aff98
* license
  Change-Id: Ifa4843831b9c1479198c2b45c5b5abad8410f448
* format
  Change-Id: I534837c309121cda9c8947acdd6c126c9c730d62
* add test
  Change-Id: I9bce66322d64b3d6296d4db7cc0a7b7b9a78763b
1 parent 75b97db commit 386994a
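In effect, safety_settings can now be a single block threshold (a string such as "block_none", an int, or a glm.SafetySetting.HarmBlockThreshold value) that is applied to every harm category. A minimal usage sketch against this commit; the model name and prompts are placeholders, and an API key is assumed to be configured via genai.configure:

import google.generativeai as genai

# One quick setting applies the same threshold to every harm category.
model = genai.GenerativeModel("gemini-pro", safety_settings="block_none")

# A per-request value still overrides the model-level setting,
# which is what the updated test below exercises.
response = model.generate_content("hello", safety_settings="high")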

File tree

3 files changed (+109, -21 lines)


google/generativeai/types/safety_types.py

Lines changed: 32 additions & 5 deletions
@@ -201,25 +201,52 @@ class LooseSafetySettingDict(TypedDict):
 EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions]
 EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions]

-SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None]
+SafetySettingOptions = Union[
+    HarmBlockThresholdOptions, EasySafetySetting, Iterable[LooseSafetySettingDict], None
+]
+
+
+def _expand_block_threshold(block_threshold: HarmBlockThresholdOptions):
+    block_threshold = to_block_threshold(block_threshold)
+    set(_HARM_CATEGORIES.values())
+    return {category: block_threshold for category in set(_HARM_CATEGORIES.values())}


 def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict:
     if settings is None:
         return {}
-    elif isinstance(settings, Mapping):
+
+    if isinstance(settings, (int, str, HarmBlockThreshold)):
+        settings = _expand_block_threshold(settings)
+
+    if isinstance(settings, Mapping):
         return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()}
+
     else:  # Iterable
-        return {
-            to_harm_category(d["category"]): to_block_threshold(d["threshold"]) for d in settings
-        }
+        result = {}
+        for setting in settings:
+            if isinstance(setting, glm.SafetySetting):
+                result[to_harm_category(setting.category)] = to_block_threshold(setting.threshold)
+            elif isinstance(setting, dict):
+                result[to_harm_category(setting["category"])] = to_block_threshold(
+                    setting["threshold"]
+                )
+            else:
+                raise ValueError(
+                    f"Could not understand safety setting:\n {type(setting)=}\n {setting=}"
+                )
+        return result


 def normalize_safety_settings(
     settings: SafetySettingOptions,
 ) -> list[SafetySettingDict] | None:
     if settings is None:
         return None
+
+    if isinstance(settings, (int, str, HarmBlockThreshold)):
+        settings = _expand_block_threshold(settings)
+
     if isinstance(settings, Mapping):
         return [
             {
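The new _expand_block_threshold helper is what makes the quick form work: both to_easy_safety_dict and normalize_safety_settings now check for a bare int, str, or HarmBlockThreshold and expand it to one entry per harm category before the existing Mapping/Iterable handling. A short sketch of the resulting behavior; treating "block_none" as an accepted string alias is an assumption based on the commit title:

import google.ai.generativelanguage as glm
from google.generativeai.types import safety_types

# A bare threshold expands to one entry per harm category.
expanded = safety_types.to_easy_safety_dict("block_none")
assert all(
    threshold == glm.SafetySetting.HarmBlockThreshold.BLOCK_NONE
    for threshold in expanded.values()
)

# Mappings and lists of dicts behave as before: only the named category is set.
partial = safety_types.to_easy_safety_dict({"danger": "medium"})
print(partial)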

tests/test_generative_models.py

Lines changed: 20 additions & 16 deletions
@@ -155,11 +155,12 @@ def test_generation_config_overwrite(self, config1, config2):

     @parameterized.named_parameters(
         ["dict", {"danger": "low"}, {"danger": "high"}],
+        ["quick", "low", "high"],
         [
             "list-dict",
             [
                 dict(
-                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                     threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
                 ),
             ],
@@ -171,44 +172,47 @@ def test_generation_config_overwrite(self, config1, config2):
             "object",
             [
                 glm.SafetySetting(
-                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                     threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
                 ),
             ],
             [
                 glm.SafetySetting(
-                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
-                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
                 ),
             ],
         ],
     )
     def test_safety_overwrite(self, safe1, safe2):
         # Safety
-        model = generative_models.GenerativeModel("gemini-pro", safety_settings={"danger": "low"})
+        model = generative_models.GenerativeModel("gemini-pro", safety_settings=safe1)

         self.responses["generate_content"] = [
             simple_response(" world!"),
             simple_response(" world!"),
         ]

         _ = model.generate_content("hello")
+
+        danger = [
+            s
+            for s in self.observed_requests[-1].safety_settings
+            if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+        ]
         self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].category,
-            glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-        )
-        self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].threshold,
+            danger[0].threshold,
             glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
         )

-        _ = model.generate_content("hello", safety_settings={"danger": "high"})
-        self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].category,
-            glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
-        )
+        _ = model.generate_content("hello", safety_settings=safe2)
+        danger = [
+            s
+            for s in self.observed_requests[-1].safety_settings
+            if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT
+        ]
         self.assertEqual(
-            self.observed_requests[-1].safety_settings[0].threshold,
+            danger[0].threshold,
             glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
         )
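The rewritten assertions filter observed_requests[-1].safety_settings by category instead of assuming the dangerous-content entry comes first, which is what lets the new "quick" parameterization pass: a bare "low" or "high" now produces a setting for every category, not just one. A standalone sketch of the normalization shared by the model-level and per-request paths (offline, no API call; "low" and "high" are the same quick aliases the test uses):

from google.generativeai.types import safety_types

# Quick strings normalize to a full list of SafetySettingDicts,
# one per harm category, for both call sites.
model_level = safety_types.normalize_safety_settings("low")
per_request = safety_types.normalize_safety_settings("high")
print(len(model_level), len(per_request))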

tests/test_safety.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import google.ai.generativelanguage as glm
+from google.generativeai.types import safety_types
+
+
+class SafetyTests(parameterized.TestCase):
+    """Tests are in order with the design doc."""
+
+    @parameterized.named_parameters(
+        ["block_threshold", glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE],
+        ["block_threshold2", "medium"],
+        ["block_threshold3", 2],
+        ["dict", {"danger": "medium"}],
+        ["dict2", {"danger": 2}],
+        ["dict3", {"danger": glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}],
+        [
+            "list-dict",
+            [
+                dict(
+                    category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                ),
+            ],
+        ],
+        [
+            "list-dict2",
+            [
+                dict(category="danger", threshold="med"),
+            ],
+        ],
+    )
+    def test_safety_overwrite(self, setting):
+        setting = safety_types.to_easy_safety_dict(setting)
+        self.assertEqual(
+            setting[glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT],
+            glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+        )
+
+
+if __name__ == "__main__":
+    absltest.main()
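Every case above should normalize the dangerous-content category to BLOCK_MEDIUM_AND_ABOVE, whether the input is the enum itself, the string "medium", the int 2, or a dict keyed by the "danger" alias. The same check can be reproduced outside the test harness as a quick sketch against this commit:

import google.ai.generativelanguage as glm
from google.generativeai.types import safety_types

# "medium", 2, and "med" are the aliases exercised by the test cases above.
for setting in ("medium", 2, {"danger": "med"}):
    easy = safety_types.to_easy_safety_dict(setting)
    assert (
        easy[glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT]
        == glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    )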

0 commit comments
