Skip to content

Commit 836fe6e

Browse files
committed
Add test
1 parent d9ae4ae commit 836fe6e

File tree

6 files changed

+377
-69
lines changed

6 files changed

+377
-69
lines changed

docs/api/models/function.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ async def model_function(
3636
print(info)
3737
"""
3838
AgentInfo(
39-
function_tools=[], allow_text_output=True, output_tools=[], model_settings=None
39+
function_tools=[],
40+
allow_text_output=True,
41+
output_tools=[],
42+
model_settings=None,
43+
instructions=None,
4044
)
4145
"""
4246
return ModelResponse(parts=[TextPart('hello world')])

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 33 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,12 @@ class ModelRequestParameters:
317317
def tool_defs(self) -> dict[str, ToolDefinition]:
318318
return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]}
319319

320+
@cached_property
321+
def prompted_output_instructions(self) -> str | None:
322+
if self.output_mode == 'prompted' and self.prompted_output_template and self.output_object:
323+
return PromptedOutputSchema.build_instructions(self.prompted_output_template, self.output_object)
324+
return None
325+
320326
__repr__ = _utils.dataclasses_no_defaults_repr
321327

322328

@@ -417,56 +423,44 @@ def prepare_request(
417423
"""
418424
model_settings = merge_model_settings(self.settings, model_settings)
419425

420-
model_request_parameters = self.customize_request_parameters(model_request_parameters)
426+
params = self.customize_request_parameters(model_request_parameters)
421427

422-
if builtin_tools := model_request_parameters.builtin_tools:
428+
if builtin_tools := params.builtin_tools:
423429
# Deduplicate builtin tools
424-
model_request_parameters = replace(
425-
model_request_parameters,
430+
params = replace(
431+
params,
426432
builtin_tools=list({tool.unique_id: tool for tool in builtin_tools}.values()),
427433
)
428434

429-
if model_request_parameters.output_mode == 'auto':
435+
if params.output_mode == 'auto':
430436
output_mode = self.profile.default_structured_output_mode
431-
model_request_parameters = replace(
432-
model_request_parameters,
437+
params = replace(
438+
params,
433439
output_mode=output_mode,
434440
allow_text_output=output_mode in ('native', 'prompted'),
435441
)
436442

437-
if model_request_parameters.output_mode in ('native', 'prompted'):
438-
assert model_request_parameters.output_object
439-
440-
if model_request_parameters.output_tools:
441-
model_request_parameters = replace(model_request_parameters, output_tools=[])
442-
else:
443-
if model_request_parameters.output_object:
444-
model_request_parameters = replace(model_request_parameters, output_object=None)
445-
446-
match model_request_parameters.output_mode:
447-
case 'native':
448-
if not self.profile.supports_json_schema_output:
449-
raise UserError('Native structured output is not supported by this model.')
450-
451-
if model_request_parameters.prompted_output_template:
452-
model_request_parameters = replace(model_request_parameters, prompted_output_template=None)
453-
case 'prompted':
454-
if not model_request_parameters.prompted_output_template:
455-
model_request_parameters = replace(
456-
model_request_parameters, prompted_output_template=self.profile.prompted_output_template
457-
)
458-
case 'tool':
459-
assert model_request_parameters.output_tools or model_request_parameters.function_tools
460-
461-
if not self.profile.supports_tools:
462-
raise UserError('Tool output is not supported by this model.')
463-
case _:
464-
pass
465-
466-
if model_request_parameters.allow_image_output and not self.profile.supports_image_output:
443+
# Reset irrelevant fields
444+
if params.output_tools and params.output_mode != 'tool':
445+
params = replace(params, output_tools=[])
446+
if params.output_object and params.output_mode not in ('native', 'prompted'):
447+
params = replace(params, output_object=None)
448+
if params.prompted_output_template and params.output_mode != 'prompted':
449+
params = replace(params, prompted_output_template=None)
450+
451+
# Set default prompted output template
452+
if params.output_mode == 'prompted' and not params.prompted_output_template:
453+
params = replace(params, prompted_output_template=self.profile.prompted_output_template)
454+
455+
# Check if output mode is supported
456+
if params.output_mode == 'native' and not self.profile.supports_json_schema_output:
457+
raise UserError('Native structured output is not supported by this model.')
458+
if params.output_mode == 'tool' and not self.profile.supports_tools:
459+
raise UserError('Tool output is not supported by this model.')
460+
if params.allow_image_output and not self.profile.supports_image_output:
467461
raise UserError('Image output is not supported by this model.')
468462

469-
return model_settings, model_request_parameters
463+
return model_settings, params
470464

471465
@property
472466
@abstractmethod
@@ -547,17 +541,7 @@ def _get_instructions(
547541
if all(p.part_kind == 'tool-return' or p.part_kind == 'retry-prompt' for p in most_recent_request.parts):
548542
instructions = second_most_recent_request.instructions
549543

550-
# TODO (DouweM): This will now not be included in ModelRequest.instructions anymore, nor in OTel. -- especially the latter may be a problem?
551-
# Unless full model_request_parameters (after processing by model) are already sent
552-
if (
553-
model_request_parameters
554-
and model_request_parameters.output_mode == 'prompted'
555-
and model_request_parameters.prompted_output_template
556-
and model_request_parameters.output_object
557-
):
558-
output_instructions = PromptedOutputSchema.build_instructions(
559-
model_request_parameters.prompted_output_template, model_request_parameters.output_object
560-
)
544+
if model_request_parameters and (output_instructions := model_request_parameters.prompted_output_instructions):
561545
if instructions:
562546
instructions = '\n\n'.join([instructions, output_instructions])
563547
else:

pydantic_ai_slim/pydantic_ai/models/fallback.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,15 @@ async def request(
8080

8181
for model in self.models:
8282
try:
83+
_, prepared_parameters = model.prepare_request(model_settings, model_request_parameters)
8384
response = await model.request(messages, model_settings, model_request_parameters)
8485
except Exception as exc:
8586
if self._fallback_on(exc):
8687
exceptions.append(exc)
8788
continue
8889
raise exc
8990

90-
self._set_span_attributes(model)
91+
self._set_span_attributes(model, prepared_parameters)
9192
return response
9293

9394
raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
@@ -106,6 +107,7 @@ async def request_stream(
106107
for model in self.models:
107108
async with AsyncExitStack() as stack:
108109
try:
110+
_, prepared_parameters = model.prepare_request(model_settings, model_request_parameters)
109111
response = await stack.enter_async_context(
110112
model.request_stream(messages, model_settings, model_request_parameters, run_context)
111113
)
@@ -115,7 +117,7 @@ async def request_stream(
115117
continue
116118
raise exc # pragma: no cover
117119

118-
self._set_span_attributes(model)
120+
self._set_span_attributes(model, prepared_parameters)
119121
yield response
120122
return
121123

@@ -128,13 +130,23 @@ def profile(self) -> ModelProfile:
128130
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
129131
return model_request_parameters
130132

131-
def _set_span_attributes(self, model: Model):
133+
def prepare_request(
134+
self, model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters
135+
) -> tuple[ModelSettings | None, ModelRequestParameters]:
136+
return model_settings, model_request_parameters
137+
138+
def _set_span_attributes(self, model: Model, model_request_parameters: ModelRequestParameters):
132139
with suppress(Exception):
133140
span = get_current_span()
134141
if span.is_recording():
135142
attributes = getattr(span, 'attributes', {})
136143
if attributes.get('gen_ai.request.model') == self.model_name: # pragma: no branch
137-
span.set_attributes(InstrumentedModel.model_attributes(model))
144+
span.set_attributes(
145+
{
146+
**InstrumentedModel.model_attributes(model),
147+
**InstrumentedModel.model_request_parameters_attributes(model_request_parameters),
148+
}
149+
)
138150

139151

140152
def _default_fallback_condition_factory(exceptions: tuple[type[Exception], ...]) -> Callable[[Exception], bool]:

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ async def request(
135135
allow_text_output=model_request_parameters.allow_text_output,
136136
output_tools=model_request_parameters.output_tools,
137137
model_settings=model_settings,
138+
model_request_parameters=model_request_parameters,
138139
instructions=self._get_instructions(messages, model_request_parameters),
139140
)
140141

@@ -169,6 +170,7 @@ async def request_stream(
169170
allow_text_output=model_request_parameters.allow_text_output,
170171
output_tools=model_request_parameters.output_tools,
171172
model_settings=model_settings,
173+
model_request_parameters=model_request_parameters,
172174
instructions=self._get_instructions(messages, model_request_parameters),
173175
)
174176

@@ -218,6 +220,8 @@ class AgentInfo:
218220
"""The tools that can be called to produce the final output of the run."""
219221
model_settings: ModelSettings | None
220222
"""The model settings passed to the run call."""
223+
model_request_parameters: ModelRequestParameters
224+
"""The model request parameters passed to the run call."""
221225
instructions: str | None
222226
"""The instructions passed to model."""
223227

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -178,17 +178,20 @@ def __init__(
178178
description='Monetary cost',
179179
)
180180

181-
def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]:
181+
def messages_to_otel_events(
182+
self, messages: list[ModelMessage], parameters: ModelRequestParameters | None = None
183+
) -> list[Event]:
182184
"""Convert a list of model messages to OpenTelemetry events.
183185
184186
Args:
185187
messages: The messages to convert.
188+
parameters: The model request parameters.
186189
187190
Returns:
188191
A list of OpenTelemetry events.
189192
"""
190193
events: list[Event] = []
191-
instructions = InstrumentedModel._get_instructions(messages) # pyright: ignore [reportPrivateUsage]
194+
instructions = InstrumentedModel._get_instructions(messages, parameters) # pyright: ignore [reportPrivateUsage]
192195
if instructions is not None:
193196
events.append(
194197
Event(
@@ -235,10 +238,17 @@ def messages_to_otel_messages(self, messages: list[ModelMessage]) -> list[_otel_
235238
result.append(otel_message)
236239
return result
237240

238-
def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span):
241+
def handle_messages(
242+
self,
243+
input_messages: list[ModelMessage],
244+
response: ModelResponse,
245+
system: str,
246+
span: Span,
247+
parameters: ModelRequestParameters | None = None,
248+
):
239249
if self.version == 1:
240-
events = self.messages_to_otel_events(input_messages)
241-
for event in self.messages_to_otel_events([response]):
250+
events = self.messages_to_otel_events(input_messages, parameters)
251+
for event in self.messages_to_otel_events([response], parameters):
242252
events.append(
243253
Event(
244254
'gen_ai.choice',
@@ -258,7 +268,7 @@ def handle_messages(self, input_messages: list[ModelMessage], response: ModelRes
258268
output_messages = self.messages_to_otel_messages([response])
259269
assert len(output_messages) == 1
260270
output_message = output_messages[0]
261-
instructions = InstrumentedModel._get_instructions(input_messages) # pyright: ignore [reportPrivateUsage]
271+
instructions = InstrumentedModel._get_instructions(input_messages, parameters) # pyright: ignore [reportPrivateUsage]
262272
system_instructions_attributes = self.system_instructions_attributes(instructions)
263273
attributes: dict[str, AttributeValue] = {
264274
'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)),
@@ -360,7 +370,7 @@ async def request(
360370
)
361371
with self._instrument(messages, prepared_settings, prepared_parameters) as finish:
362372
response = await self.wrapped.request(messages, model_settings, model_request_parameters)
363-
finish(response)
373+
finish(response, prepared_parameters)
364374
return response
365375

366376
@asynccontextmanager
@@ -384,15 +394,15 @@ async def request_stream(
384394
yield response_stream
385395
finally:
386396
if response_stream: # pragma: no branch
387-
finish(response_stream.get())
397+
finish(response_stream.get(), prepared_parameters)
388398

389399
@contextmanager
390400
def _instrument(
391401
self,
392402
messages: list[ModelMessage],
393403
model_settings: ModelSettings | None,
394404
model_request_parameters: ModelRequestParameters,
395-
) -> Iterator[Callable[[ModelResponse], None]]:
405+
) -> Iterator[Callable[[ModelResponse, ModelRequestParameters], None]]:
396406
operation = 'chat'
397407
span_name = f'{operation} {self.model_name}'
398408
# TODO Missing attributes:
@@ -401,7 +411,7 @@ def _instrument(
401411
attributes: dict[str, AttributeValue] = {
402412
'gen_ai.operation.name': operation,
403413
**self.model_attributes(self.wrapped),
404-
'model_request_parameters': json.dumps(InstrumentedModel.serialize_any(model_request_parameters)),
414+
**self.model_request_parameters_attributes(model_request_parameters),
405415
'logfire.json_schema': json.dumps(
406416
{
407417
'type': 'object',
@@ -419,7 +429,7 @@ def _instrument(
419429
try:
420430
with self.instrumentation_settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
421431

422-
def finish(response: ModelResponse):
432+
def finish(response: ModelResponse, parameters: ModelRequestParameters):
423433
# FallbackModel updates these span attributes.
424434
attributes.update(getattr(span, 'attributes', {}))
425435
request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
@@ -443,7 +453,7 @@ def _record_metrics():
443453
if not span.is_recording():
444454
return
445455

446-
self.instrumentation_settings.handle_messages(messages, response, system, span)
456+
self.instrumentation_settings.handle_messages(messages, response, system, span, parameters)
447457

448458
attributes_to_set = {
449459
**response.usage.opentelemetry_attributes(),
@@ -476,7 +486,7 @@ def _record_metrics():
476486
record_metrics()
477487

478488
@staticmethod
479-
def model_attributes(model: Model):
489+
def model_attributes(model: Model) -> dict[str, AttributeValue]:
480490
attributes: dict[str, AttributeValue] = {
481491
GEN_AI_SYSTEM_ATTRIBUTE: model.system,
482492
GEN_AI_REQUEST_MODEL_ATTRIBUTE: model.model_name,
@@ -494,6 +504,12 @@ def model_attributes(model: Model):
494504

495505
return attributes
496506

507+
@staticmethod
508+
def model_request_parameters_attributes(
509+
model_request_parameters: ModelRequestParameters,
510+
) -> dict[str, AttributeValue]:
511+
return {'model_request_parameters': json.dumps(InstrumentedModel.serialize_any(model_request_parameters))}
512+
497513
@staticmethod
498514
def event_to_dict(event: Event) -> dict[str, Any]:
499515
if not event.body:

0 commit comments

Comments
 (0)