4
4
from contextlib import asynccontextmanager
5
5
from dataclasses import dataclass
6
6
from datetime import datetime
7
- from typing import Any
7
+ from typing import Any , cast
8
8
9
9
from pydantic import ConfigDict , with_config
10
10
from temporalio import activity , workflow
30
30
@with_config (ConfigDict (arbitrary_types_allowed = True ))
31
31
class _RequestParams :
32
32
messages : list [ModelMessage ]
33
- model_settings : ModelSettings | None
33
+ # `model_settings` can't be a `ModelSettings` because Temporal would end up dropping fields only defined on its subclasses.
34
+ model_settings : dict [str , Any ] | None
34
35
model_request_parameters : ModelRequestParameters
35
36
serialized_run_context : Any
36
37
@@ -82,7 +83,11 @@ def __init__(
82
83
83
84
@activity .defn (name = f'{ activity_name_prefix } __model_request' )
84
85
async def request_activity (params : _RequestParams ) -> ModelResponse :
85
- return await self .wrapped .request (params .messages , params .model_settings , params .model_request_parameters )
86
+ return await self .wrapped .request (
87
+ params .messages ,
88
+ cast (ModelSettings | None , params .model_settings ),
89
+ params .model_request_parameters ,
90
+ )
86
91
87
92
self .request_activity = request_activity
88
93
@@ -92,7 +97,10 @@ async def request_stream_activity(params: _RequestParams, deps: AgentDepsT) -> M
92
97
93
98
run_context = self .run_context_type .deserialize_run_context (params .serialized_run_context , deps = deps )
94
99
async with self .wrapped .request_stream (
95
- params .messages , params .model_settings , params .model_request_parameters , run_context
100
+ params .messages ,
101
+ cast (ModelSettings | None , params .model_settings ),
102
+ params .model_request_parameters ,
103
+ run_context ,
96
104
) as streamed_response :
97
105
await self .event_stream_handler (run_context , streamed_response )
98
106
@@ -124,7 +132,7 @@ async def request(
124
132
activity = self .request_activity ,
125
133
arg = _RequestParams (
126
134
messages = messages ,
127
- model_settings = model_settings ,
135
+ model_settings = cast ( dict [ str , Any ] | None , model_settings ) ,
128
136
model_request_parameters = model_request_parameters ,
129
137
serialized_run_context = None ,
130
138
),
@@ -161,7 +169,7 @@ async def request_stream(
161
169
args = [
162
170
_RequestParams (
163
171
messages = messages ,
164
- model_settings = model_settings ,
172
+ model_settings = cast ( dict [ str , Any ] | None , model_settings ) ,
165
173
model_request_parameters = model_request_parameters ,
166
174
serialized_run_context = serialized_run_context ,
167
175
),
0 commit comments