1111 contextmanager ,
1212 ExitStack ,
1313)
14- from types import TracebackType
14+ from types import EllipsisType , TracebackType
1515from typing import (
1616 Any ,
1717 ContextManager ,
8080 LMStudioCancelledError ,
8181 LMStudioClientError ,
8282 LMStudioPredictionError ,
83+ LMStudioTimeoutError ,
8384 LMStudioWebsocket ,
8485 LoadModelEndpoint ,
8586 ModelDownloadOptionBase ,
141142 "PredictionStream" ,
142143 "configure_default_client" ,
143144 "get_default_client" ,
145+ "get_sync_api_timeout" ,
144146 "embedding_model" ,
145147 "list_downloaded_models" ,
146148 "list_loaded_models" ,
147149 "llm" ,
148150 "prepare_image" ,
151+ "set_sync_api_timeout" ,
149152]
150153
154+ #
155+ _DEFAULT_TIMEOUT : float | None = 60.0
156+
157+
158+ def get_sync_api_timeout () -> float | None :
159+ """Return the current default sync API timeout when waiting for server messages."""
160+ return _DEFAULT_TIMEOUT
161+
162+
163+ def set_sync_api_timeout (timeout : float | None ) -> None :
164+ """Set the default sync API timeout when waiting for server messages."""
165+ global _DEFAULT_TIMEOUT
166+ if timeout is not None :
167+ timeout = float (timeout )
168+ _DEFAULT_TIMEOUT = timeout
169+
151170
152171T = TypeVar ("T" )
172+ CallWithTimeout : TypeAlias = Callable [[float | None ], Any ]
173+ TimeoutOption : TypeAlias = float | None | EllipsisType
153174
154175
155176class SyncChannel (Generic [T ]):
@@ -158,16 +179,18 @@ class SyncChannel(Generic[T]):
158179 def __init__ (
159180 self ,
160181 channel_id : int ,
161- get_message : Callable [[], Any ] ,
182+ get_message : CallWithTimeout ,
162183 endpoint : ChannelEndpoint [T , Any , Any ],
163184 send_json : Callable [[DictObject ], None ],
164185 log_context : LogEventContext ,
186+ timeout : TimeoutOption = ...,
165187 ) -> None :
166188 """Initialize synchronous websocket streaming channel."""
167189 self ._is_finished = False
168190 self ._get_message = get_message
169- self ._api_channel = ChannelHandler (channel_id , endpoint , log_context )
170191 self ._send_json = send_json
192+ self ._timeout = timeout
193+ self ._api_channel = ChannelHandler (channel_id , endpoint , log_context )
171194
172195 def get_creation_message (self ) -> DictObject :
173196 """Get the message to send to create this channel."""
@@ -185,6 +208,14 @@ def cancel(self) -> None:
185208 cancel_message = self ._api_channel .get_cancel_message ()
186209 self ._send_json (cancel_message )
187210
211+ @property
212+ def timeout (self ) -> float | None :
213+ """Permitted time between received messages for this channel."""
214+ timeout = self ._timeout
215+ if timeout is ...:
216+ return _DEFAULT_TIMEOUT
217+ return timeout
218+
188219 def rx_stream (
189220 self ,
190221 ) -> Iterator [DictObject | None ]:
@@ -193,7 +224,10 @@ def rx_stream(
193224 with sdk_public_api ():
194225 # Avoid emitting tracebacks that delve into supporting libraries
195226 # (we can't easily suppress the SDK's own frames for iterators)
196- message = self ._get_message ()
227+ try :
228+ message = self ._get_message (self .timeout )
229+ except TimeoutError :
230+ raise LMStudioTimeoutError from None
197231 contents = self ._api_channel .handle_rx_message (message )
198232 if contents is None :
199233 self ._is_finished = True
@@ -216,12 +250,14 @@ class SyncRemoteCall:
216250 def __init__ (
217251 self ,
218252 call_id : int ,
219- get_message : Callable [[], Any ] ,
253+ get_message : CallWithTimeout ,
220254 log_context : LogEventContext ,
221255 notice_prefix : str = "RPC" ,
256+ timeout : TimeoutOption = ...,
222257 ) -> None :
223258 """Initialize synchronous remote procedure call."""
224259 self ._get_message = get_message
260+ self ._timeout = timeout
225261 self ._rpc = RemoteCallHandler (call_id , log_context , notice_prefix )
226262 self ._logger = logger = new_logger (type (self ).__name__ )
227263 logger .update_context (log_context , call_id = call_id )
@@ -232,9 +268,20 @@ def get_rpc_message(
232268 """Get the message to send to initiate this remote procedure call."""
233269 return self ._rpc .get_rpc_message (endpoint , params )
234270
271+ @property
272+ def timeout (self ) -> float | None :
273+ """Permitted time to wait for a reply to this call."""
274+ timeout = self ._timeout
275+ if timeout is ...:
276+ return _DEFAULT_TIMEOUT
277+ return timeout
278+
235279 def receive_result (self ) -> Any :
236280 """Receive call response on the receive queue."""
237- message = self ._get_message ()
281+ try :
282+ message = self ._get_message (self .timeout )
283+ except TimeoutError :
284+ raise LMStudioTimeoutError from None
238285 return self ._rpc .handle_rx_message (message )
239286
240287
0 commit comments