Skip to content

Commit a2883d2

Browse files
authored
Fix TemporalAgent dropping model-specific ModelSettings (e.g. openai_reasoning_effort) (#2938)
1 parent aca70eb commit a2883d2

File tree

2 files changed

+57
-7
lines changed

2 files changed

+57
-7
lines changed

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from contextlib import asynccontextmanager
55
from dataclasses import dataclass
66
from datetime import datetime
7-
from typing import Any
7+
from typing import Any, cast
88

99
from pydantic import ConfigDict, with_config
1010
from temporalio import activity, workflow
@@ -30,7 +30,8 @@
3030
@with_config(ConfigDict(arbitrary_types_allowed=True))
3131
class _RequestParams:
3232
messages: list[ModelMessage]
33-
model_settings: ModelSettings | None
33+
# `model_settings` can't be a `ModelSettings` because Temporal would end up dropping fields only defined on its subclasses.
34+
model_settings: dict[str, Any] | None
3435
model_request_parameters: ModelRequestParameters
3536
serialized_run_context: Any
3637

@@ -82,7 +83,11 @@ def __init__(
8283

8384
@activity.defn(name=f'{activity_name_prefix}__model_request')
8485
async def request_activity(params: _RequestParams) -> ModelResponse:
85-
return await self.wrapped.request(params.messages, params.model_settings, params.model_request_parameters)
86+
return await self.wrapped.request(
87+
params.messages,
88+
cast(ModelSettings | None, params.model_settings),
89+
params.model_request_parameters,
90+
)
8691

8792
self.request_activity = request_activity
8893

@@ -92,7 +97,10 @@ async def request_stream_activity(params: _RequestParams, deps: AgentDepsT) -> M
9297

9398
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
9499
async with self.wrapped.request_stream(
95-
params.messages, params.model_settings, params.model_request_parameters, run_context
100+
params.messages,
101+
cast(ModelSettings | None, params.model_settings),
102+
params.model_request_parameters,
103+
run_context,
96104
) as streamed_response:
97105
await self.event_stream_handler(run_context, streamed_response)
98106

@@ -124,7 +132,7 @@ async def request(
124132
activity=self.request_activity,
125133
arg=_RequestParams(
126134
messages=messages,
127-
model_settings=model_settings,
135+
model_settings=cast(dict[str, Any] | None, model_settings),
128136
model_request_parameters=model_request_parameters,
129137
serialized_run_context=None,
130138
),
@@ -161,7 +169,7 @@ async def request_stream(
161169
args=[
162170
_RequestParams(
163171
messages=messages,
164-
model_settings=model_settings,
172+
model_settings=cast(dict[str, Any] | None, model_settings),
165173
model_request_parameters=model_request_parameters,
166174
serialized_run_context=serialized_run_context,
167175
),

tests/test_temporal.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313
from pydantic import BaseModel
1414

15-
from pydantic_ai import Agent, RunContext
15+
from pydantic_ai import Agent, ModelSettings, RunContext
1616
from pydantic_ai.direct import model_request_stream
1717
from pydantic_ai.exceptions import ApprovalRequired, CallDeferred, ModelRetry, UserError
1818
from pydantic_ai.messages import (
@@ -33,6 +33,7 @@
3333
UserPromptPart,
3434
)
3535
from pydantic_ai.models import Model, cached_async_http_client
36+
from pydantic_ai.models.function import AgentInfo, FunctionModel
3637
from pydantic_ai.run import AgentRunResult
3738
from pydantic_ai.tools import DeferredToolRequests, DeferredToolResults, ToolDefinition
3839
from pydantic_ai.toolsets import ExternalToolset, FunctionToolset
@@ -1931,3 +1932,44 @@ async def test_temporal_agent_with_model_retry(allow_model_requests: None, clien
19311932
),
19321933
]
19331934
)
1935+
1936+
1937+
class CustomModelSettings(ModelSettings, total=False):
1938+
custom_setting: str
1939+
1940+
1941+
def return_settings(messages: list[ModelMessage], agent_info: AgentInfo) -> ModelResponse:
1942+
return ModelResponse(parts=[TextPart(str(agent_info.model_settings))])
1943+
1944+
1945+
model_settings = CustomModelSettings(max_tokens=123, custom_setting='custom_value')
1946+
model = FunctionModel(return_settings, settings=model_settings)
1947+
1948+
settings_agent = Agent(model, name='settings_agent')
1949+
1950+
# This needs to be done before the `TemporalAgent` is bound to the workflow.
1951+
settings_temporal_agent = TemporalAgent(settings_agent, activity_config=BASE_ACTIVITY_CONFIG)
1952+
1953+
1954+
@workflow.defn
1955+
class SettingsAgentWorkflow:
1956+
@workflow.run
1957+
async def run(self, prompt: str) -> str:
1958+
result = await settings_temporal_agent.run(prompt)
1959+
return result.output
1960+
1961+
1962+
async def test_custom_model_settings(allow_model_requests: None, client: Client):
1963+
async with Worker(
1964+
client,
1965+
task_queue=TASK_QUEUE,
1966+
workflows=[SettingsAgentWorkflow],
1967+
plugins=[AgentPlugin(settings_temporal_agent)],
1968+
):
1969+
output = await client.execute_workflow( # pyright: ignore[reportUnknownMemberType]
1970+
SettingsAgentWorkflow.run,
1971+
args=['Give me those settings'],
1972+
id=SettingsAgentWorkflow.__name__,
1973+
task_queue=TASK_QUEUE,
1974+
)
1975+
assert output == snapshot("{'max_tokens': 123, 'custom_setting': 'custom_value'}")

0 commit comments

Comments
 (0)