diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py index 1b03ecc..028ced4 100644 --- a/src/lmstudio/async_api.py +++ b/src/lmstudio/async_api.py @@ -32,7 +32,6 @@ from .sdk_api import LMStudioRuntimeError, sdk_public_api, sdk_public_api_async from .schemas import AnyLMStudioStruct, DictObject, DictSchema, ModelSchema from .history import ( - AssistantResponse, Chat, ChatHistoryDataDict, _FileHandle, @@ -50,6 +49,8 @@ CompletionEndpoint, DEFAULT_TTL, DownloadedModelBase, + DownloadFinalizedCallback, + DownloadProgressCallback, EmbeddingLoadModelConfig, EmbeddingLoadModelConfigDict, EmbeddingModelInfo, @@ -65,16 +66,21 @@ LMStudioWebsocket, LMStudioWebsocketError, LoadModelEndpoint, - ModelInstanceInfo, ModelDownloadOptionBase, ModelHandleBase, + ModelInstanceInfo, + ModelLoadingCallback, ModelSessionTypes, ModelTypesEmbedding, ModelTypesLlm, PredictionStreamBase, PredictionEndpoint, + PredictionFirstTokenCallback, + PredictionFragmentCallback, PredictionFragmentEvent, + PredictionMessageCallback, PredictionResult, + PromptProcessingCallback, RemoteCallHandler, TModelInfo, TPrediction, @@ -84,7 +90,6 @@ ) from ._kv_config import TLoadConfig, TLoadConfigDict, dict_from_fields_key from ._sdk_models import ( - DownloadProgressUpdate, EmbeddingRpcEmbedStringParameter, EmbeddingRpcTokenizeParameter, LlmApplyPromptTemplateOpts, @@ -481,7 +486,7 @@ async def load_new_instance( ttl: int | None = DEFAULT_TTL, instance_identifier: str | None = None, config: TLoadConfig | TLoadConfigDict | None = None, - on_load_progress: Callable[[float], None] | None = None, + on_load_progress: ModelLoadingCallback | None = None, ) -> TAsyncModelHandle: """Load this model with the given identifier and configuration.""" handle: TAsyncModelHandle = await self._session._load_new_instance( @@ -495,7 +500,7 @@ async def model( *, ttl: int | None = DEFAULT_TTL, config: TLoadConfig | TLoadConfigDict | None = None, - on_load_progress: Callable[[float], None] | None = None, + on_load_progress: ModelLoadingCallback | None = None, ) -> TAsyncModelHandle: # Call _get_or_load directly, since we have a model identifier handle: TAsyncModelHandle = await self._session._get_or_load( @@ -597,14 +602,11 @@ async def _add_temp_file( class AsyncModelDownloadOption(ModelDownloadOptionBase[AsyncSession]): """A single download option for a model search result.""" - # We prefer using DownloadProgressUpdate as a parameter type because then - # if we update the schema down the line, no need to rewrite callback signatures. - # Plus, Callable[[int, int, float], None] is not very descriptive. 
@sdk_public_api_async() async def download( self, - on_progress: Callable[[DownloadProgressUpdate], None] | None = None, - on_finalize: Callable[[], None] | None = None, + on_progress: DownloadProgressCallback | None = None, + on_finalize: DownloadFinalizedCallback | None = None, ) -> str: """Download a model and get its path for loading.""" endpoint = self._get_download_endpoint(on_progress, on_finalize) @@ -747,7 +749,7 @@ async def model( *, ttl: int | None = DEFAULT_TTL, config: TLoadConfig | TLoadConfigDict | None = None, - on_load_progress: Callable[[float], None] | None = None, + on_load_progress: ModelLoadingCallback | None = None, ) -> TAsyncModelHandle: """Get a handle to the specified model (loading it if necessary).""" if model_key is None: @@ -777,7 +779,7 @@ async def load_new_instance( *, ttl: int | None = DEFAULT_TTL, config: TLoadConfig | TLoadConfigDict | None = None, - on_load_progress: Callable[[float], None] | None = None, + on_load_progress: ModelLoadingCallback | None = None, ) -> TAsyncModelHandle: """Load the specified model with the given identifier and configuration.""" return await self._load_new_instance( @@ -790,7 +792,7 @@ async def _load_new_instance( instance_identifier: str | None, ttl: int | None, config: TLoadConfig | TLoadConfigDict | None, - on_load_progress: Callable[[float], None] | None, + on_load_progress: ModelLoadingCallback | None, ) -> TAsyncModelHandle: channel_type = self._API_TYPES.REQUEST_NEW_INSTANCE config_type = self._API_TYPES.MODEL_LOAD_CONFIG @@ -812,7 +814,7 @@ async def _get_or_load( model_key: str, ttl: int | None, config: TLoadConfig | TLoadConfigDict | None, - on_load_progress: Callable[[float], None] | None, + on_load_progress: ModelLoadingCallback | None, ) -> TAsyncModelHandle: """Load the specified model with the given identifier and configuration.""" channel_type = self._API_TYPES.REQUEST_GET_OR_LOAD @@ -966,10 +968,10 @@ async def _complete_stream( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> AsyncPredictionStream[str]: ... @overload async def _complete_stream( @@ -979,10 +981,10 @@ async def _complete_stream( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> AsyncPredictionStream[DictObject]: ... 
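# --- illustrative usage, not part of the patch -----------------------------
# A minimal sketch of the new DownloadProgressCallback / DownloadFinalizedCallback
# aliases in use with the download() change above. `option` stands for an
# AsyncModelDownloadOption obtained from a model search elsewhere; that lookup,
# and the assumption that DownloadProgressUpdate is re-exported from the
# top-level `lmstudio` package, are not established by this diff.
from lmstudio import DownloadProgressUpdate

def log_progress(update: DownloadProgressUpdate) -> None:
    # DownloadProgressCallback: receives the whole update object, so the callback
    # signature stays stable even if the schema gains extra fields later.
    print(f"download progress: {update!r}")

def log_finalized() -> None:
    # DownloadFinalizedCallback: no arguments; called once the download is finalized.
    print("download finalized, waiting for the final path")

async def download_with_progress(option) -> str:
    return await option.download(on_progress=log_progress, on_finalize=log_finalized)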
async def _complete_stream( self, @@ -991,10 +993,10 @@ async def _complete_stream( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]: """Request a one-off prediction without any context and stream the generated tokens.""" endpoint = CompletionEndpoint( @@ -1019,10 +1021,10 @@ async def _respond_stream( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> AsyncPredictionStream[str]: ... @overload async def _respond_stream( @@ -1032,10 +1034,10 @@ async def _respond_stream( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> AsyncPredictionStream[DictObject]: ... 
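# --- illustrative usage, not part of the patch -----------------------------
# Sketch of ModelLoadingCallback (Callable[[float], Any]) as accepted by the
# on_load_progress parameters above. `llm_session` stands for an async LLM
# session obtained from an already-connected client; treating the reported
# float as a 0.0-1.0 fraction is an assumption, not something this diff states.

def report_load_progress(progress: float) -> None:
    print(f"model load progress: {progress:.0%}")

async def load_with_progress(llm_session, model_key: str):
    # Works the same for model() and load_new_instance(); only model() is shown.
    return await llm_session.model(model_key, on_load_progress=report_load_progress)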
async def _respond_stream( self, @@ -1043,11 +1045,11 @@ async def _respond_stream( history: Chat | ChatHistoryDataDict | str, *, response_format: Type[ModelSchema] | DictSchema | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, + on_message: PredictionMessageCallback | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]: """Request a response in an ongoing assistant chat session and stream the generated tokens.""" if not isinstance(history, Chat): @@ -1186,10 +1188,10 @@ async def complete_stream( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> AsyncPredictionStream[str]: ... @overload async def complete_stream( @@ -1198,10 +1200,10 @@ async def complete_stream( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> AsyncPredictionStream[DictObject]: ... 
@sdk_public_api_async() async def complete_stream( @@ -1210,10 +1212,10 @@ async def complete_stream( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]: """Request a one-off prediction without any context and stream the generated tokens.""" return await self._session._complete_stream( @@ -1234,10 +1236,10 @@ async def complete( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionResult[str]: ... @overload async def complete( @@ -1246,10 +1248,10 @@ async def complete( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionResult[DictObject]: ... 
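# --- illustrative usage, not part of the patch -----------------------------
# Sketch combining the prediction callback aliases on complete(). `model` stands
# for an async LLM handle obtained elsewhere, and the .content attribute on
# LlmPredictionFragment is an assumption about the wider SDK, not part of this diff.

def notify_first_token() -> None:               # PredictionFirstTokenCallback
    print("prompt processed; first token received")

def print_fragment(fragment) -> None:           # PredictionFragmentCallback
    print(fragment.content, end="", flush=True)  # assumed attribute name

def show_prompt_progress(progress: float) -> None:   # PromptProcessingCallback
    print(f"prompt processing: {progress:.0%}")

async def complete_with_callbacks(model, prompt: str):
    return await model.complete(
        prompt,
        on_first_token=notify_first_token,
        on_prediction_fragment=print_fragment,
        on_prompt_processing_progress=show_prompt_progress,
    )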
@sdk_public_api_async() async def complete( @@ -1258,10 +1260,10 @@ async def complete( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> PredictionResult[str] | PredictionResult[DictObject]: """Request a one-off prediction without any context.""" prediction_stream = await self._session._complete_stream( @@ -1287,10 +1289,10 @@ async def respond_stream( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> AsyncPredictionStream[str]: ... @overload async def respond_stream( @@ -1299,10 +1301,10 @@ async def respond_stream( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> AsyncPredictionStream[DictObject]: ... 
@sdk_public_api_async() async def respond_stream( @@ -1311,10 +1313,10 @@ async def respond_stream( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> AsyncPredictionStream[str] | AsyncPredictionStream[DictObject]: """Request a response in an ongoing assistant chat session and stream the generated tokens.""" return await self._session._respond_stream( @@ -1335,10 +1337,10 @@ async def respond( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionResult[str]: ... @overload async def respond( @@ -1347,10 +1349,10 @@ async def respond( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionResult[DictObject]: ... 
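# --- illustrative usage, not part of the patch -----------------------------
# Sketch of PredictionMessageCallback: on_message receives the finished assistant
# message once the prediction completes. `model` is an async LLM handle and `chat`
# a Chat history, both assumed to be set up elsewhere; what the callback does with
# the AssistantResponse (here: collecting it in a list) is up to the caller.

transcript: list = []

def record_message(message) -> None:            # PredictionMessageCallback
    transcript.append(message)

async def ask(model, chat):
    return await model.respond(chat, on_message=record_message)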
@sdk_public_api_async() async def respond( @@ -1359,10 +1361,10 @@ async def respond( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> PredictionResult[str] | PredictionResult[DictObject]: """Request a response in an ongoing assistant chat session.""" prediction_stream = await self._session._respond_stream( diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index d57037c..d6c954e 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -125,6 +125,9 @@ __all__ = [ "ActResult", "AnyModelSpecifier", + "DownloadFinalizedCallback", + "DownloadProgressCallback", + "DownloadProgressUpdate", "EmbeddingModelInfo", "EmbeddingModelInstanceInfo", "EmbeddingLoadModelConfig", @@ -152,8 +155,12 @@ "ModelSpecifierDict", "ModelQuery", "ModelQueryDict", + "PredictionFirstTokenCallback", + "PredictionFragmentCallback", + "PredictionMessageCallback", "PredictionResult", "PredictionRoundResult", + "PromptProcessingCallback", "SerializedLMSExtendedError", "ToolFunctionDef", "ToolFunctionDefDict", @@ -712,6 +719,9 @@ class ModelDownloadFinalizeEvent(ChannelRxEvent[None]): ModelDownloadProgressEvent | ModelDownloadFinalizeEvent | ChannelCommonRxEvent ) +DownloadProgressCallback: TypeAlias = Callable[[DownloadProgressUpdate], Any] +DownloadFinalizedCallback: TypeAlias = Callable[[], Any] + class ModelDownloadEndpoint( ChannelEndpoint[str, ModelDownloadRxEvent, DownloadModelChannelRequestDict] @@ -724,8 +734,8 @@ class ModelDownloadEndpoint( def __init__( self, download_identifier: str, - on_progress: Callable[[DownloadProgressUpdate], None] | None = None, - on_finalize: Callable[[], None] | None = None, + on_progress: DownloadProgressCallback | None = None, + on_finalize: DownloadFinalizedCallback | None = None, ) -> None: params = DownloadModelChannelRequest._from_api_dict( {"downloadIdentifier": download_identifier} @@ -803,6 +813,8 @@ class ModelLoadingProgressEvent(ChannelRxEvent[float]): ModelLoadingRxEvent: TypeAlias = ModelLoadingProgressEvent | ChannelCommonRxEvent +ModelLoadingCallback: TypeAlias = Callable[[float], Any] + class _ModelLoadingEndpoint( ChannelEndpoint[ModelLoadResult, ModelLoadingRxEvent, TWireFormat] @@ -811,7 +823,7 @@ def __init__( self, model_key: str, creation_params: LMStudioStruct[TWireFormat] | DictObject, - on_load_progress: Callable[[float], None] | None = None, + on_load_progress: ModelLoadingCallback | None = None, ) -> None: super().__init__(creation_params) self._logger.update_context(model_key=model_key) @@ -902,7 +914,7 @@ def __init__( creation_param_type: Type[LoadModelChannelRequest], config_type: Type[TLoadConfig], config: TLoadConfig | TLoadConfigDict | None, - on_load_progress: Callable[[float], None] | None, + on_load_progress: ModelLoadingCallback | None, ) -> None: """Load the specified model with the given identifier and configuration.""" kv_config = load_config_to_kv_config_stack(config, config_type) @@ -934,7 +946,7 @@ def 
__init__( creation_param_type: Type[GetOrLoadChannelRequest], config_type: Type[TLoadConfig], config: TLoadConfig | TLoadConfigDict | None = None, - on_load_progress: Callable[[float], None] | None = None, + on_load_progress: ModelLoadingCallback | None = None, ) -> None: """Get the specified model, loading with given configuration if necessary.""" kv_config = load_config_to_kv_config_stack(config, config_type) @@ -1052,6 +1064,11 @@ class PredictionToolCallAbortedEvent(ChannelRxEvent[None]): ClientToolSpec: TypeAlias = tuple[type[Struct], Callable[..., Any]] ClientToolMap: TypeAlias = Mapping[str, ClientToolSpec] +PredictionMessageCallback: TypeAlias = Callable[[AssistantResponse], Any] +PredictionFirstTokenCallback: TypeAlias = Callable[[], Any] +PredictionFragmentCallback: TypeAlias = Callable[[LlmPredictionFragment], Any] +PromptProcessingCallback: TypeAlias = Callable[[float], Any] + class PredictionEndpoint( Generic[TPrediction], @@ -1069,10 +1086,10 @@ def __init__( history: Chat, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, # The remaining options are only relevant for multi-round tool actions handle_invalid_tool_request: Callable[ [LMStudioPredictionError, _ToolCallRequest | None], str @@ -1320,10 +1337,10 @@ def __init__( prompt: str, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> None: """Load the specified model with the given identifier and configuration.""" history = Chat() @@ -1888,8 +1905,8 @@ def info(self) -> ModelSearchResultDownloadOptionData: def _get_download_endpoint( self, - on_progress: Callable[[DownloadProgressUpdate], None] | None = None, - on_finalize: Callable[[], None] | None = None, + on_progress: DownloadProgressCallback | None = None, + on_finalize: DownloadFinalizedCallback | None = None, ) -> ModelDownloadEndpoint: # Throw a more specific exception than the one thrown by remote_call data = self._data diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index ddbd899..972c2b8 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -64,6 +64,8 @@ CompletionEndpoint, DEFAULT_TTL, DownloadedModelBase, + DownloadFinalizedCallback, + DownloadProgressCallback, EmbeddingLoadModelConfig, EmbeddingLoadModelConfigDict, EmbeddingModelInfo, @@ -80,19 +82,24 @@ LMStudioWebsocket, LMStudioWebsocketError, LoadModelEndpoint, - 
ModelInstanceInfo, ModelDownloadOptionBase, ModelHandleBase, + ModelInstanceInfo, + ModelLoadingCallback, ModelSessionTypes, ModelTypesEmbedding, ModelTypesLlm, PredictionEndpoint, + PredictionFirstTokenCallback, + PredictionFragmentCallback, PredictionFragmentEvent, + PredictionMessageCallback, PredictionResult, PredictionRoundResult, PredictionRxEvent, PredictionStreamBase, PredictionToolCallEvent, + PromptProcessingCallback, RemoteCallHandler, TModelInfo, TPrediction, @@ -104,7 +111,6 @@ ) from ._kv_config import TLoadConfig, TLoadConfigDict, dict_from_fields_key from ._sdk_models import ( - DownloadProgressUpdate, EmbeddingRpcEmbedStringParameter, EmbeddingRpcTokenizeParameter, LlmApplyPromptTemplateOpts, @@ -654,7 +660,7 @@ def load_new_instance( ttl: int | None = DEFAULT_TTL, instance_identifier: str | None = None, config: TLoadConfig | TLoadConfigDict | None = None, - on_load_progress: Callable[[float], None] | None = None, + on_load_progress: ModelLoadingCallback | None = None, ) -> TModelHandle: """Load this model with the given identifier and configuration.""" handle: TModelHandle = self._session._load_new_instance( @@ -668,7 +674,7 @@ def model( *, ttl: int | None = DEFAULT_TTL, config: TLoadConfig | TLoadConfigDict | None = None, - on_load_progress: Callable[[float], None] | None = None, + on_load_progress: ModelLoadingCallback | None = None, ) -> TModelHandle: # Call _get_or_load directly, since we have a model identifier handle: TModelHandle = self._session._get_or_load( @@ -766,14 +772,11 @@ def _add_temp_file( class ModelDownloadOption(ModelDownloadOptionBase[SyncSession]): """A single download option for a model search result.""" - # We prefer using DownloadProgressUpdate as a parameter type because then - # if we update the schema down the line, no need to rewrite callback signatures. - # Plus, Callable[[int, int, float], None] is not very descriptive. 
@sdk_public_api() def download( self, - on_progress: Callable[[DownloadProgressUpdate], None] | None = None, - on_finalize: Callable[[], None] | None = None, + on_progress: DownloadProgressCallback | None = None, + on_finalize: DownloadFinalizedCallback | None = None, ) -> str: """Download a model and get its path for loading.""" endpoint = self._get_download_endpoint(on_progress, on_finalize) @@ -904,7 +907,7 @@ def model( *, ttl: int | None = DEFAULT_TTL, config: TLoadConfig | TLoadConfigDict | None = None, - on_load_progress: Callable[[float], None] | None = None, + on_load_progress: ModelLoadingCallback | None = None, ) -> TModelHandle: """Get a handle to the specified model (loading it if necessary).""" if model_key is None: @@ -934,7 +937,7 @@ def load_new_instance( *, ttl: int | None = DEFAULT_TTL, config: TLoadConfig | TLoadConfigDict | None = None, - on_load_progress: Callable[[float], None] | None = None, + on_load_progress: ModelLoadingCallback | None = None, ) -> TModelHandle: """Load the specified model with the given identifier and configuration.""" return self._load_new_instance( @@ -947,7 +950,7 @@ def _load_new_instance( instance_identifier: str | None, ttl: int | None, config: TLoadConfig | TLoadConfigDict | None, - on_load_progress: Callable[[float], None] | None, + on_load_progress: ModelLoadingCallback | None, ) -> TModelHandle: channel_type = self._API_TYPES.REQUEST_NEW_INSTANCE config_type = self._API_TYPES.MODEL_LOAD_CONFIG @@ -969,7 +972,7 @@ def _get_or_load( model_key: str, ttl: int | None, config: TLoadConfig | TLoadConfigDict | None, - on_load_progress: Callable[[float], None] | None, + on_load_progress: ModelLoadingCallback | None, ) -> TModelHandle: """Get the specified model if it is already loaded, otherwise load it.""" channel_type = self._API_TYPES.REQUEST_GET_OR_LOAD @@ -1123,10 +1126,10 @@ def _complete_stream( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionStream[str]: ... @overload def _complete_stream( @@ -1136,10 +1139,10 @@ def _complete_stream( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionStream[DictObject]: ... 
def _complete_stream( self, @@ -1148,10 +1151,10 @@ def _complete_stream( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> PredictionStream[str] | PredictionStream[DictObject]: """Request a one-off prediction without any context and stream the generated tokens.""" endpoint = CompletionEndpoint( @@ -1176,10 +1179,10 @@ def _respond_stream( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionStream[str]: ... @overload def _respond_stream( @@ -1189,10 +1192,10 @@ def _respond_stream( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionStream[DictObject]: ... 
def _respond_stream( self, @@ -1201,10 +1204,10 @@ def _respond_stream( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> PredictionStream[str] | PredictionStream[DictObject]: """Request a response in an ongoing assistant chat session and stream the generated tokens.""" if not isinstance(history, Chat): @@ -1341,10 +1344,10 @@ def complete_stream( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionStream[str]: ... @overload def complete_stream( @@ -1353,10 +1356,10 @@ def complete_stream( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionStream[DictObject]: ... 
@sdk_public_api() def complete_stream( @@ -1365,10 +1368,10 @@ def complete_stream( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> PredictionStream[str] | PredictionStream[DictObject]: """Request a one-off prediction without any context and stream the generated tokens.""" return self._session._complete_stream( @@ -1389,10 +1392,10 @@ def complete( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionResult[str]: ... @overload def complete( @@ -1401,10 +1404,10 @@ def complete( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionResult[DictObject]: ... 
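# --- illustrative usage, not part of the patch -----------------------------
# Sync counterpart: the same callback aliases are accepted unchanged. That the
# returned PredictionStream can be iterated for fragments, and that fragments
# expose .content, are assumptions about the wider SDK rather than this diff.

def print_stream_fragment(fragment) -> None:    # PredictionFragmentCallback
    print(fragment.content, end="", flush=True)  # assumed attribute name

def stream_completion(model, prompt: str) -> None:
    stream = model.complete_stream(prompt, on_prediction_fragment=print_stream_fragment)
    for _fragment in stream:  # assumed to be iterable; the callback above already prints
        pass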
@sdk_public_api() def complete( @@ -1413,10 +1416,10 @@ def complete( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> PredictionResult[str] | PredictionResult[DictObject]: """Request a one-off prediction without any context.""" prediction_stream = self._session._complete_stream( @@ -1442,10 +1445,10 @@ def respond_stream( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionStream[str]: ... @overload def respond_stream( @@ -1454,10 +1457,10 @@ def respond_stream( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionStream[DictObject]: ... 
@sdk_public_api() def respond_stream( @@ -1466,10 +1469,10 @@ def respond_stream( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> PredictionStream[str] | PredictionStream[DictObject]: """Request a response in an ongoing assistant chat session and stream the generated tokens.""" return self._session._respond_stream( @@ -1490,10 +1493,10 @@ def respond( *, response_format: Literal[None] = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionResult[str]: ... @overload def respond( @@ -1502,10 +1505,10 @@ def respond( *, response_format: Type[ModelSchema] | DictSchema = ..., config: LlmPredictionConfig | LlmPredictionConfigDict | None = ..., - on_message: Callable[[AssistantResponse], None] | None = ..., - on_first_token: Callable[[], None] | None = ..., - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = ..., - on_prompt_processing_progress: Callable[[float], None] | None = ..., + on_message: PredictionMessageCallback | None = ..., + on_first_token: PredictionFirstTokenCallback | None = ..., + on_prediction_fragment: PredictionFragmentCallback | None = ..., + on_prompt_processing_progress: PromptProcessingCallback | None = ..., ) -> PredictionResult[DictObject]: ... 
@sdk_public_api() def respond( @@ -1514,10 +1517,10 @@ def respond( *, response_format: Type[ModelSchema] | DictSchema | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse], None] | None = None, - on_first_token: Callable[[], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None, - on_prompt_processing_progress: Callable[[float], None] | None = None, + on_message: PredictionMessageCallback | None = None, + on_first_token: PredictionFirstTokenCallback | None = None, + on_prediction_fragment: PredictionFragmentCallback | None = None, + on_prompt_processing_progress: PromptProcessingCallback | None = None, ) -> PredictionResult[str] | PredictionResult[DictObject]: """Request a response in an ongoing assistant chat session.""" prediction_stream = self._session._respond_stream( @@ -1538,6 +1541,7 @@ def respond( # Multi-round predictions are currently a sync-only handle-only feature # TODO: Refactor to allow for more code sharing with the async API + # with defined aliases for the expected callback signatures @sdk_public_api() def act( self, @@ -1546,15 +1550,15 @@ def act( *, max_prediction_rounds: int | None = None, config: LlmPredictionConfig | LlmPredictionConfigDict | None = None, - on_message: Callable[[AssistantResponse | ToolResultMessage], None] + on_message: Callable[[AssistantResponse | ToolResultMessage], Any] | None = None, - on_first_token: Callable[[int], None] | None = None, - on_prediction_fragment: Callable[[LlmPredictionFragment, int], None] + on_first_token: Callable[[int], Any] | None = None, + on_prediction_fragment: Callable[[LlmPredictionFragment, int], Any] | None = None, - on_round_start: Callable[[int], None] | None = None, - on_round_end: Callable[[int], None] | None = None, - on_prediction_completed: Callable[[PredictionRoundResult], None] | None = None, - on_prompt_processing_progress: Callable[[float, int], None] | None = None, + on_round_start: Callable[[int], Any] | None = None, + on_round_end: Callable[[int], Any] | None = None, + on_prediction_completed: Callable[[PredictionRoundResult], Any] | None = None, + on_prompt_processing_progress: Callable[[float, int], Any] | None = None, handle_invalid_tool_request: Callable[ [LMStudioPredictionError, _ToolCallRequest | None], str ] @@ -1585,7 +1589,7 @@ def act( del tools # Supply the round index to any endpoint callbacks that expect one round_index: int - on_first_token_for_endpoint: Callable[[], None] | None = None + on_first_token_for_endpoint: PredictionFirstTokenCallback | None = None if on_first_token is not None: def _wrapped_on_first_token() -> None: @@ -1593,9 +1597,7 @@ def _wrapped_on_first_token() -> None: on_first_token(round_index) on_first_token_for_endpoint = _wrapped_on_first_token - on_prediction_fragment_for_endpoint: ( - Callable[[LlmPredictionFragment], None] | None - ) = None + on_prediction_fragment_for_endpoint: PredictionFragmentCallback | None = None if on_prediction_fragment is not None: def _wrapped_on_prediction_fragment( @@ -1605,7 +1607,7 @@ def _wrapped_on_prediction_fragment( on_prediction_fragment(fragment, round_index) on_prediction_fragment_for_endpoint = _wrapped_on_prediction_fragment - on_prompt_processing_for_endpoint: Callable[[float], None] | None = None + on_prompt_processing_for_endpoint: PromptProcessingCallback | None = None if on_prompt_processing_progress is not None: def _wrapped_on_prompt_processing_progress(progress: float) -> None:
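# --- illustrative usage, not part of the patch -----------------------------
# Sketch of the multi-round act() callbacks, which keep plain Callable annotations
# above: as the _wrapped_* helpers show, the per-round callbacks also receive the
# round index. `model`, `chat` and `tools` are assumed to be prepared elsewhere.

def on_round_start(round_index: int) -> None:
    print(f"round {round_index}: started")

def on_round_fragment(fragment, round_index: int) -> None:
    # Matches Callable[[LlmPredictionFragment, int], Any]
    print(f"round {round_index}: received a fragment")

def on_round_completed(result) -> None:
    # Matches Callable[[PredictionRoundResult], Any]
    print(f"round completed: {result!r}")

def run_agent(model, chat, tools) -> None:
    model.act(
        chat,
        tools,
        on_round_start=on_round_start,
        on_prediction_fragment=on_round_fragment,
        on_prediction_completed=on_round_completed,
    )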