Skip to content

Commit b2adad9

Browse files
Add Ability to Specify Gemini Safety Settings (#790)
Co-authored-by: Sydney Runkle <[email protected]> Co-authored-by: sydney-runkle <[email protected]>
1 parent 738c890 commit b2adad9

File tree

4 files changed

+169
-16
lines changed

4 files changed

+169
-16
lines changed

docs/agents.md

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -206,26 +206,44 @@ print(result_sync.data)
206206

207207
### Model specific settings
208208

209-
<!-- TODO: replace this with the gemini safety settings example once added via https://github.com/pydantic/pydantic-ai/issues/373 -->
210-
211-
If you wish to further customize model behavior, you can use a subclass of [`ModelSettings`][pydantic_ai.settings.ModelSettings], like [`AnthropicModelSettings`][pydantic_ai.models.anthropic.AnthropicModelSettings], associated with your model of choice.
209+
If you wish to further customize model behavior, you can use a subclass of [`ModelSettings`][pydantic_ai.settings.ModelSettings], like [`GeminiModelSettings`][pydantic_ai.models.gemini.GeminiModelSettings], associated with your model of choice.
212210

213211
For example:
214212

215213
```py
216-
from pydantic_ai import Agent
217-
from pydantic_ai.models.anthropic import AnthropicModelSettings
214+
from pydantic_ai import Agent, UnexpectedModelBehavior
215+
from pydantic_ai.models.gemini import GeminiModelSettings
218216

219-
agent = Agent('anthropic:claude-3-5-sonnet-latest')
217+
agent = Agent('google-gla:gemini-1.5-flash')
220218

221-
result_sync = agent.run_sync(
222-
'What is the capital of Italy?',
223-
model_settings=AnthropicModelSettings(anthropic_metadata={'user_id': 'my_user_id'}),
224-
)
225-
print(result_sync.data)
226-
#> Rome
219+
try:
220+
result = agent.run_sync(
221+
'Write a list of 5 very rude things that I might say to the universe after stubbing my toe in the dark:',
222+
model_settings=GeminiModelSettings(
223+
temperature=0.0, # general model settings can also be specified
224+
gemini_safety_settings=[
225+
{
226+
'category': 'HARM_CATEGORY_HARASSMENT',
227+
'threshold': 'BLOCK_LOW_AND_ABOVE',
228+
},
229+
{
230+
'category': 'HARM_CATEGORY_HATE_SPEECH',
231+
'threshold': 'BLOCK_LOW_AND_ABOVE',
232+
},
233+
],
234+
),
235+
)
236+
except UnexpectedModelBehavior as e:
237+
print(e) # (1)!
238+
"""
239+
Safety settings triggered, body:
240+
<safety settings details>
241+
"""
227242
```
228243

244+
1. This error is raised because the safety thresholds were exceeded.
245+
If the safety thresholds had not been exceeded, `result` would instead contain a normal `ModelResponse`.
246+
229247
## Runs vs. Conversations
230248

231249
An agent **run** might represent an entire conversation — there's no limit to how many messages can be exchanged in a single run. However, a **conversation** might also be composed of multiple runs, especially if you need to maintain state between separate interactions or API calls.

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
class GeminiModelSettings(ModelSettings):
5858
"""Settings used for a Gemini model request."""
5959

60-
# This class is a placeholder for any future gemini-specific settings
60+
gemini_safety_settings: list[GeminiSafetySettings]
6161

6262

6363
@dataclass(init=False)
@@ -192,6 +192,8 @@ async def _make_request(
192192
generation_config['presence_penalty'] = presence_penalty
193193
if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
194194
generation_config['frequency_penalty'] = frequency_penalty
195+
if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) != []:
196+
request_data['safety_settings'] = gemini_safety_settings
195197
if generation_config:
196198
request_data['generation_config'] = generation_config
197199

@@ -220,6 +222,11 @@ async def _make_request(
220222
def _process_response(self, response: _GeminiResponse) -> ModelResponse:
221223
if len(response['candidates']) != 1:
222224
raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
225+
if 'content' not in response['candidates'][0]:
226+
if response['candidates'][0].get('finish_reason') == 'SAFETY':
227+
raise UnexpectedModelBehavior('Safety settings triggered', str(response))
228+
else:
229+
raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
223230
parts = response['candidates'][0]['content']['parts']
224231
return _process_response_from_parts(parts, model_name=self.model_name)
225232

@@ -237,7 +244,7 @@ async def _process_streamed_response(self, http_response: HTTPResponse) -> Strea
237244
)
238245
if responses:
239246
last = responses[-1]
240-
if last['candidates'] and last['candidates'][0]['content']['parts']:
247+
if last['candidates'] and last['candidates'][0].get('content', {}).get('parts'):
241248
start_response = last
242249
break
243250

@@ -310,6 +317,8 @@ class GeminiStreamedResponse(StreamedResponse):
310317
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
311318
async for gemini_response in self._get_gemini_responses():
312319
candidate = gemini_response['candidates'][0]
320+
if 'content' not in candidate:
321+
raise UnexpectedModelBehavior('Streamed response has no content field')
313322
gemini_part: _GeminiPartUnion
314323
for gemini_part in candidate['content']['parts']:
315324
if 'text' in gemini_part:
@@ -383,6 +392,7 @@ class _GeminiRequest(TypedDict):
383392
contents: list[_GeminiContent]
384393
tools: NotRequired[_GeminiTools]
385394
tool_config: NotRequired[_GeminiToolConfig]
395+
safety_settings: NotRequired[list[GeminiSafetySettings]]
386396
# we don't implement `generationConfig`, instead we use a named tool for the response
387397
system_instruction: NotRequired[_GeminiTextContent]
388398
"""
@@ -392,6 +402,38 @@ class _GeminiRequest(TypedDict):
392402
generation_config: NotRequired[_GeminiGenerationConfig]
393403

394404

405+
class GeminiSafetySettings(TypedDict):
406+
"""Safety settings options for Gemini model request.
407+
408+
See [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for safety category and threshold descriptions.
409+
For an example of how to use `GeminiSafetySettings`, see [here](../../agents.md#model-specific-settings).
410+
"""
411+
412+
category: Literal[
413+
'HARM_CATEGORY_UNSPECIFIED',
414+
'HARM_CATEGORY_HARASSMENT',
415+
'HARM_CATEGORY_HATE_SPEECH',
416+
'HARM_CATEGORY_SEXUALLY_EXPLICIT',
417+
'HARM_CATEGORY_DANGEROUS_CONTENT',
418+
'HARM_CATEGORY_CIVIC_INTEGRITY',
419+
]
420+
"""
421+
Safety settings category.
422+
"""
423+
424+
threshold: Literal[
425+
'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
426+
'BLOCK_LOW_AND_ABOVE',
427+
'BLOCK_MEDIUM_AND_ABOVE',
428+
'BLOCK_ONLY_HIGH',
429+
'BLOCK_NONE',
430+
'OFF',
431+
]
432+
"""
433+
Safety settings threshold.
434+
"""
435+
436+
395437
class _GeminiGenerationConfig(TypedDict, total=False):
396438
"""Schema for an API request to the Gemini API.
397439
@@ -568,8 +610,8 @@ class _GeminiResponse(TypedDict):
568610
class _GeminiCandidates(TypedDict):
569611
"""See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""
570612

571-
content: _GeminiContent
572-
finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
613+
content: NotRequired[_GeminiContent]
614+
finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS', 'SAFETY'], pydantic.Field(alias='finishReason')]]
573615
"""
574616
See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
575617
but let's wait until we see them and know what they mean to add them here.
@@ -617,6 +659,7 @@ class _GeminiSafetyRating(TypedDict):
617659
'HARM_CATEGORY_CIVIC_INTEGRITY',
618660
]
619661
probability: Literal['NEGLIGIBLE', 'LOW', 'MEDIUM', 'HIGH']
662+
blocked: NotRequired[bool]
620663

621664

622665
class _GeminiPromptFeedback(TypedDict):

tests/models/test_gemini.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pydantic_ai.models.gemini import (
2929
ApiKeyAuth,
3030
GeminiModel,
31+
GeminiModelSettings,
3132
_content_model_response,
3233
_function_call_part_from_call,
3334
_gemini_response_ta,
@@ -37,6 +38,7 @@
3738
_GeminiFunction,
3839
_GeminiFunctionCallingConfig,
3940
_GeminiResponse,
41+
_GeminiSafetyRating,
4042
_GeminiTextPart,
4143
_GeminiToolConfig,
4244
_GeminiTools,
@@ -865,3 +867,90 @@ def handler(request: httpx.Request) -> httpx.Response:
865867
},
866868
)
867869
assert result.data == 'world'
870+
871+
872+
def gemini_no_content_response(
873+
safety_ratings: list[_GeminiSafetyRating], finish_reason: Literal['SAFETY'] | None = 'SAFETY'
874+
) -> _GeminiResponse:
875+
candidate = _GeminiCandidates(safety_ratings=safety_ratings)
876+
if finish_reason:
877+
candidate['finish_reason'] = finish_reason
878+
return _GeminiResponse(candidates=[candidate], usage_metadata=example_usage())
879+
880+
881+
async def test_safety_settings_unsafe(
882+
client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
883+
) -> None:
884+
try:
885+
886+
def handler(request: httpx.Request) -> httpx.Response:
887+
safety_settings = json.loads(request.content)['safety_settings']
888+
assert safety_settings == [
889+
{'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
890+
{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
891+
]
892+
893+
return httpx.Response(
894+
200,
895+
content=_gemini_response_ta.dump_json(
896+
gemini_no_content_response(
897+
finish_reason='SAFETY',
898+
safety_ratings=[
899+
{'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'MEDIUM', 'blocked': True}
900+
],
901+
),
902+
by_alias=True,
903+
),
904+
headers={'Content-Type': 'application/json'},
905+
)
906+
907+
gemini_client = client_with_handler(handler)
908+
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
909+
agent = Agent(m)
910+
911+
await agent.run(
912+
'a request for something rude',
913+
model_settings=GeminiModelSettings(
914+
gemini_safety_settings=[
915+
{'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
916+
{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
917+
]
918+
),
919+
)
920+
except UnexpectedModelBehavior as e:
921+
assert repr(e) == "UnexpectedModelBehavior('Safety settings triggered')"
922+
923+
924+
async def test_safety_settings_safe(
925+
client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None
926+
) -> None:
927+
def handler(request: httpx.Request) -> httpx.Response:
928+
safety_settings = json.loads(request.content)['safety_settings']
929+
assert safety_settings == [
930+
{'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
931+
{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
932+
]
933+
934+
return httpx.Response(
935+
200,
936+
content=_gemini_response_ta.dump_json(
937+
gemini_response(_content_model_response(ModelResponse(parts=[TextPart('world')]))),
938+
by_alias=True,
939+
),
940+
headers={'Content-Type': 'application/json'},
941+
)
942+
943+
gemini_client = client_with_handler(handler)
944+
m = GeminiModel('gemini-1.5-flash', http_client=gemini_client, api_key='mock')
945+
agent = Agent(m)
946+
947+
result = await agent.run(
948+
'hello',
949+
model_settings=GeminiModelSettings(
950+
gemini_safety_settings=[
951+
{'category': 'HARM_CATEGORY_CIVIC_INTEGRITY', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
952+
{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_LOW_AND_ABOVE'},
953+
]
954+
),
955+
)
956+
assert result.data == 'world'

tests/test_examples.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pytest_mock import MockerFixture
1818

1919
from pydantic_ai._utils import group_by_temporal
20+
from pydantic_ai.exceptions import UnexpectedModelBehavior
2021
from pydantic_ai.messages import (
2122
ModelMessage,
2223
ModelResponse,
@@ -288,6 +289,8 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes
288289
)
289290
]
290291
)
292+
elif m.content.startswith('Write a list of 5 very rude things that I might say'):
293+
raise UnexpectedModelBehavior('Safety settings triggered', body='<safety settings details>')
291294
elif m.content.startswith('<examples>\n <user>'):
292295
return ModelResponse(parts=[ToolCallPart(tool_name='final_result_EmailOk', args={})])
293296
elif m.content == 'Ask a simple question with a single correct answer.' and len(messages) > 2:

0 commit comments

Comments
 (0)