44from contextlib import asynccontextmanager
55from dataclasses import dataclass
66from datetime import datetime
7- from typing import Any
7+ from typing import Any , cast
88
99from pydantic import ConfigDict , with_config
1010from temporalio import activity , workflow
3030@with_config (ConfigDict (arbitrary_types_allowed = True ))
3131class _RequestParams :
3232 messages : list [ModelMessage ]
33- model_settings : ModelSettings | None
33+ # `model_settings` can't be a `ModelSettings` because Temporal would end up dropping fields only defined on its subclasses.
34+ model_settings : dict [str , Any ] | None
3435 model_request_parameters : ModelRequestParameters
3536 serialized_run_context : Any
3637
@@ -82,7 +83,11 @@ def __init__(
8283
8384 @activity .defn (name = f'{ activity_name_prefix } __model_request' )
8485 async def request_activity (params : _RequestParams ) -> ModelResponse :
85- return await self .wrapped .request (params .messages , params .model_settings , params .model_request_parameters )
86+ return await self .wrapped .request (
87+ params .messages ,
88+ cast (ModelSettings | None , params .model_settings ),
89+ params .model_request_parameters ,
90+ )
8691
8792 self .request_activity = request_activity
8893
@@ -92,7 +97,10 @@ async def request_stream_activity(params: _RequestParams, deps: AgentDepsT) -> M
9297
9398 run_context = self .run_context_type .deserialize_run_context (params .serialized_run_context , deps = deps )
9499 async with self .wrapped .request_stream (
95- params .messages , params .model_settings , params .model_request_parameters , run_context
100+ params .messages ,
101+ cast (ModelSettings | None , params .model_settings ),
102+ params .model_request_parameters ,
103+ run_context ,
96104 ) as streamed_response :
97105 await self .event_stream_handler (run_context , streamed_response )
98106
@@ -124,7 +132,7 @@ async def request(
124132 activity = self .request_activity ,
125133 arg = _RequestParams (
126134 messages = messages ,
127- model_settings = model_settings ,
135+ model_settings = cast ( dict [ str , Any ] | None , model_settings ) ,
128136 model_request_parameters = model_request_parameters ,
129137 serialized_run_context = None ,
130138 ),
@@ -161,7 +169,7 @@ async def request_stream(
161169 args = [
162170 _RequestParams (
163171 messages = messages ,
164- model_settings = model_settings ,
172+ model_settings = cast ( dict [ str , Any ] | None , model_settings ) ,
165173 model_request_parameters = model_request_parameters ,
166174 serialized_run_context = serialized_run_context ,
167175 ),
0 commit comments