Skip to content

Commit 39daccd

Browse files
authored
Add stop_sequences to ModelSettings (#1419)
1 parent 1a60d8f commit 39daccd

18 files changed: +558 −21 lines

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ async def _messages_create(
226226
tools=tools or NOT_GIVEN,
227227
tool_choice=tool_choice or NOT_GIVEN,
228228
stream=stream,
229+
stop_sequences=model_settings.get('stop_sequences', NOT_GIVEN),
229230
temperature=model_settings.get('temperature', NOT_GIVEN),
230231
top_p=model_settings.get('top_p', NOT_GIVEN),
231232
timeout=model_settings.get('timeout', NOT_GIVEN),

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,8 @@ def _map_inference_config(
294294
inference_config['temperature'] = temperature
295295
if top_p := model_settings.get('top_p'):
296296
inference_config['topP'] = top_p
297-
# TODO(Marcelo): This is not included in model_settings yet.
298-
# if stop_sequences := model_settings.get('stop_sequences'):
299-
# inference_config['stopSequences'] = stop_sequences
297+
if stop_sequences := model_settings.get('stop_sequences'):
298+
inference_config['stopSequences'] = stop_sequences
300299

301300
return inference_config
302301

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __init__(
118118
'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
119119
created using the other parameters.
120120
"""
121-
self._model_name: CohereModelName = model_name
121+
self._model_name = model_name
122122

123123
if isinstance(provider, str):
124124
provider = infer_provider(provider)
@@ -163,6 +163,7 @@ async def _chat(
163163
messages=cohere_messages,
164164
tools=tools or OMIT,
165165
max_tokens=model_settings.get('max_tokens', OMIT),
166+
stop_sequences=model_settings.get('stop_sequences', OMIT),
166167
temperature=model_settings.get('temperature', OMIT),
167168
p=model_settings.get('top_p', OMIT),
168169
seed=model_settings.get('seed', OMIT),

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ class _GeminiGenerationConfig(TypedDict, total=False):
506506
top_p: float
507507
presence_penalty: float
508508
frequency_penalty: float
509+
stop_sequences: list[str]
509510

510511

511512
class _GeminiContent(TypedDict):

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ async def _completions_create(
208208
parallel_tool_calls=model_settings.get('parallel_tool_calls', NOT_GIVEN),
209209
tools=tools or NOT_GIVEN,
210210
tool_choice=tool_choice or NOT_GIVEN,
211+
stop=model_settings.get('stop_sequences', NOT_GIVEN),
211212
stream=stream,
212213
max_tokens=model_settings.get('max_tokens', NOT_GIVEN),
213214
temperature=model_settings.get('temperature', NOT_GIVEN),

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ async def _completions_create(
199199
top_p=model_settings.get('top_p', 1),
200200
timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
201201
random_seed=model_settings.get('seed', UNSET),
202+
stop=model_settings.get('stop_sequences', None),
202203
)
203204
except SDKError as e:
204205
if (status_code := e.status_code) >= 400:
@@ -236,6 +237,7 @@ async def _stream_completions_create(
236237
timeout_ms=self._get_timeout_ms(model_settings.get('timeout')),
237238
presence_penalty=model_settings.get('presence_penalty'),
238239
frequency_penalty=model_settings.get('frequency_penalty'),
240+
stop=model_settings.get('stop_sequences', None),
239241
)
240242

241243
elif model_request_parameters.result_tools:

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ async def _completions_create(
271271
tool_choice=tool_choice or NOT_GIVEN,
272272
stream=stream,
273273
stream_options={'include_usage': True} if stream else NOT_GIVEN,
274+
stop=model_settings.get('stop_sequences', NOT_GIVEN),
274275
max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN),
275276
temperature=model_settings.get('temperature', NOT_GIVEN),
276277
top_p=model_settings.get('top_p', NOT_GIVEN),
@@ -611,7 +612,7 @@ async def _responses_create(
611612
truncation=model_settings.get('openai_truncation', NOT_GIVEN),
612613
timeout=model_settings.get('timeout', NOT_GIVEN),
613614
reasoning=reasoning,
614-
user=model_settings.get('user', NOT_GIVEN),
615+
user=model_settings.get('openai_user', NOT_GIVEN),
615616
)
616617
except APIStatusError as e:
617618
if (status_code := e.status_code) >= 400:

pydantic_ai_slim/pydantic_ai/settings.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,19 @@ class ModelSettings(TypedDict, total=False):
128128
* Groq
129129
"""
130130

131+
stop_sequences: list[str]
132+
"""Sequences that will cause the model to stop generating.
133+
134+
Supported by:
135+
136+
* OpenAI
137+
* Anthropic
138+
* Bedrock
139+
* Mistral
140+
* Groq
141+
* Cohere
142+
"""
143+
131144

132145
def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
133146
"""Merge two sets of model settings, preferring the overrides.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
interactions:
2+
- request:
3+
headers:
4+
accept:
5+
- application/json
6+
accept-encoding:
7+
- gzip, deflate
8+
connection:
9+
- keep-alive
10+
content-length:
11+
- '193'
12+
content-type:
13+
- application/json
14+
host:
15+
- api.anthropic.com
16+
method: POST
17+
parsed_body:
18+
max_tokens: 1024
19+
messages:
20+
- content:
21+
- text: What is the capital of France?
22+
type: text
23+
role: user
24+
model: claude-3-5-sonnet-latest
25+
stop_sequences:
26+
- Paris
27+
stream: false
28+
uri: https://api.anthropic.com/v1/messages
29+
response:
30+
headers:
31+
connection:
32+
- keep-alive
33+
content-length:
34+
- '333'
35+
content-type:
36+
- application/json
37+
transfer-encoding:
38+
- chunked
39+
parsed_body:
40+
content:
41+
- text: 'The capital of France is '
42+
type: text
43+
id: msg_01FfkgikmbDFzn9XE1YYkJmA
44+
model: claude-3-5-sonnet-20241022
45+
role: assistant
46+
stop_reason: stop_sequence
47+
stop_sequence: Paris
48+
type: message
49+
usage:
50+
cache_creation_input_tokens: 0
51+
cache_read_input_tokens: 0
52+
input_tokens: 14
53+
output_tokens: 6
54+
status:
55+
code: 200
56+
message: OK
57+
version: 1
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
interactions:
2+
- request:
3+
body: '{"messages": [{"role": "user", "content": [{"text": "What is the capital of France?"}]}], "system": [], "inferenceConfig":
4+
{"stopSequences": ["Paris"]}}'
5+
headers:
6+
amz-sdk-invocation-id:
7+
- !!binary |
8+
YzVkZjljOTMtMDQ1Zi00NWE0LWJhY2YtMDAwMjdjYTg1NmRl
9+
amz-sdk-request:
10+
- !!binary |
11+
YXR0ZW1wdD0x
12+
content-length:
13+
- '152'
14+
content-type:
15+
- !!binary |
16+
YXBwbGljYXRpb24vanNvbg==
17+
method: POST
18+
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/us.amazon.nova-micro-v1%3A0/converse
19+
response:
20+
headers:
21+
connection:
22+
- keep-alive
23+
content-length:
24+
- '209'
25+
content-type:
26+
- application/json
27+
parsed_body:
28+
metrics:
29+
latencyMs: 179
30+
output:
31+
message:
32+
content:
33+
- text: The capital of France is Paris
34+
role: assistant
35+
stopReason: end_turn
36+
usage:
37+
inputTokens: 7
38+
outputTokens: 6
39+
totalTokens: 13
40+
status:
41+
code: 200
42+
message: OK
43+
version: 1

0 commit comments

Comments (0)