diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py
index ad2c831..80851ff 100644
--- a/src/lmstudio/async_api.py
+++ b/src/lmstudio/async_api.py
@@ -35,7 +35,7 @@
     Chat,
     ChatHistoryDataDict,
     FileHandle,
-    _FileCacheInputType,
+    LocalFileInput,
     _LocalFileData,
 )
 from .json_api import (
@@ -590,8 +590,8 @@ async def _fetch_file_handle(self, file_data: _LocalFileData) -> FileHandle:
         return load_struct(handle, FileHandle)
 
     @sdk_public_api_async()
-    async def _add_temp_file(
-        self, src: _FileCacheInputType, name: str | None = None
+    async def prepare_file(
+        self, src: LocalFileInput, name: str | None = None
     ) -> FileHandle:
         """Add a file to the server."""
         # Private until LM Studio file handle support stabilizes
@@ -1502,12 +1502,12 @@ def repository(self) -> AsyncSessionRepository:
 
     # Convenience methods
     @sdk_public_api_async()
-    async def _add_temp_file(
-        self, src: _FileCacheInputType, name: str | None = None
+    async def prepare_file(
+        self, src: LocalFileInput, name: str | None = None
     ) -> FileHandle:
         """Add a file to the server."""
         # Private until LM Studio file handle support stabilizes
-        return await self._files._add_temp_file(src, name)
+        return await self._files.prepare_file(src, name)
 
     @sdk_public_api_async()
     async def list_downloaded_models(
diff --git a/src/lmstudio/history.py b/src/lmstudio/history.py
index 1913253..de2a58d 100644
--- a/src/lmstudio/history.py
+++ b/src/lmstudio/history.py
@@ -370,10 +370,8 @@ def add_user_message(
         self,
         content: UserMessageInput | Iterable[UserMessageInput],
         *,
+        files: Sequence[FileHandleInput] = (),
         images: Sequence[FileHandleInput] = (),
-        # Mark file parameters as private until LM Studio
-        # file handle support stabilizes
-        _files: Sequence[FileHandleInput] = (),
     ) -> UserMessage:
         """Add a new user message to the chat history."""
         # Accept both singular and multi-part user messages
@@ -383,10 +381,10 @@ def add_user_message(
         else:
             content_items = list(content)
         # Convert given local file information to file handles
+        if files:
+            content_items.extend(files)
         if images:
             content_items.extend(images)
-        if _files:
-            content_items.extend(_files)
         # Consecutive messages with the same role are not supported,
         # but multi-part user messages are valid (to allow for file
         # attachments), so just merge them
@@ -519,14 +517,16 @@ def add_tool_result(self, result: ToolCallResultInput) -> ToolResultMessage:
         return message
 
 
+LocalFileInput = BinaryIO | bytes | str | os.PathLike[str]
+
+
 # Private until file handle caching support is part of the published SDK API
-_FileCacheInputType = BinaryIO | bytes | str | os.PathLike[str]
 
 
-def _get_file_details(src: _FileCacheInputType) -> Tuple[str, bytes]:
+def _get_file_details(src: LocalFileInput) -> Tuple[str, bytes]:
     """Read file contents as binary data and generate a suitable default name."""
     if isinstance(src, bytes):
-        # We interpreter bytes as raw data, not a bytes filesystem path
+        # We process bytes as raw data, not a bytes filesystem path
         data = src
         name = str(uuid.uuid4())
     elif hasattr(src, "read"):
@@ -555,14 +555,13 @@
 _FileHandleCacheKey: TypeAlias = tuple[str, _ContentHash]
 
 
-# Private until file handle caching support is part of the published SDK API
 class _LocalFileData:
     """Local file data to be added to a chat history."""
 
     name: str
     raw_data: bytes
 
-    def __init__(self, src: _FileCacheInputType, name: str | None = None) -> None:
+    def __init__(self, src: LocalFileInput, name: str | None = None) -> None:
         default_name, raw_data = _get_file_details(src)
         self.name = name or default_name
         self.raw_data = raw_data
@@ -594,7 +593,7 @@ def __init__(self) -> None:
 
     @sdk_public_api()
     def _get_file_handle(
-        self, src: _FileCacheInputType, name: str | None = None
+        self, src: LocalFileInput, name: str | None = None
     ) -> FileHandle:
         file_data = _LocalFileData(src, name)
         cache_key = file_data._get_cache_key()
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index 358eda0..a1ec696 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -53,7 +53,7 @@
     Chat,
     ChatHistoryDataDict,
     FileHandle,
-    _FileCacheInputType,
+    LocalFileInput,
     _LocalFileData,
     ToolCallRequest,
 )
@@ -765,9 +765,7 @@ def _fetch_file_handle(self, file_data: _LocalFileData) -> FileHandle:
         return load_struct(handle, FileHandle)
 
     @sdk_public_api()
-    def _add_temp_file(
-        self, src: _FileCacheInputType, name: str | None = None
-    ) -> FileHandle:
+    def prepare_file(self, src: LocalFileInput, name: str | None = None) -> FileHandle:
         """Add a file to the server."""
         # Private until LM Studio file handle support stabilizes
         file_data = _LocalFileData(src, name)
@@ -1820,12 +1818,10 @@ def repository(self) -> SyncSessionRepository:
 
     # Convenience methods
     @sdk_public_api()
-    def _add_temp_file(
-        self, src: _FileCacheInputType, name: str | None = None
-    ) -> FileHandle:
+    def prepare_file(self, src: LocalFileInput, name: str | None = None) -> FileHandle:
         """Add a file to the server."""
         # Private until LM Studio file handle support stabilizes
-        return self._files._add_temp_file(src, name)
+        return self._files.prepare_file(src, name)
 
     @sdk_public_api()
     def list_downloaded_models(
@@ -1895,10 +1891,10 @@ def embedding_model(
 
 
 @sdk_public_api()
-def _add_temp_file(src: _FileCacheInputType, name: str | None = None) -> FileHandle:
+def prepare_file(src: LocalFileInput, name: str | None = None) -> FileHandle:
     """Add a file to the server using the default global client."""
     # Private until LM Studio file handle support stabilizes
-    return get_default_client()._add_temp_file(src, name)
+    return get_default_client().prepare_file(src, name)
 
 
 @sdk_public_api()
diff --git a/tests/async/test_images_async.py b/tests/async/test_images_async.py
index 69f178b..d9b6221 100644
--- a/tests/async/test_images_async.py
+++ b/tests/async/test_images_async.py
@@ -24,7 +24,7 @@ async def test_upload_from_pathlike_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     async with AsyncClient() as client:
         session = client._files
-        file = await session._add_temp_file(IMAGE_FILEPATH)
+        file = await session.prepare_file(IMAGE_FILEPATH)
         assert file
         assert isinstance(file, FileHandle)
         logging.info(f"Uploaded file: {file}")
@@ -37,7 +37,7 @@ async def test_upload_from_file_obj_async(caplog: LogCap) -> None:
     async with AsyncClient() as client:
         session = client._files
         with open(IMAGE_FILEPATH, "rb") as f:
-            file = await session._add_temp_file(f)
+            file = await session.prepare_file(f)
         assert file
         assert isinstance(file, FileHandle)
         logging.info(f"Uploaded file: {file}")
@@ -50,7 +50,7 @@ async def test_upload_from_bytesio_async(caplog: LogCap) -> None:
     async with AsyncClient() as client:
         session = client._files
         with open(IMAGE_FILEPATH, "rb") as f:
-            file = await session._add_temp_file(BytesIO(f.read()))
+            file = await session.prepare_file(BytesIO(f.read()))
         assert file
         assert isinstance(file, FileHandle)
         logging.info(f"Uploaded file: {file}")
@@ -64,7 +64,7 @@ async def test_vlm_predict_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     model_id = EXPECTED_VLM_ID
     async with AsyncClient() as client:
-        file_handle = await client._files._add_temp_file(IMAGE_FILEPATH)
+        file_handle = await client._files.prepare_file(IMAGE_FILEPATH)
         history = Chat()
         history.add_user_message((prompt, file_handle))
         vlm = await client.llm.model(model_id)
@@ -84,7 +84,7 @@ async def test_non_vlm_predict_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     model_id = "hugging-quants/llama-3.2-1b-instruct"
     async with AsyncClient() as client:
-        file_handle = await client._files._add_temp_file(IMAGE_FILEPATH)
+        file_handle = await client._files.prepare_file(IMAGE_FILEPATH)
         history = Chat()
         history.add_user_message((prompt, file_handle))
         llm = await client.llm.model(model_id)
@@ -101,7 +101,7 @@ async def test_vlm_predict_image_param_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     model_id = EXPECTED_VLM_ID
     async with AsyncClient() as client:
-        file_handle = await client._files._add_temp_file(IMAGE_FILEPATH)
+        file_handle = await client._files.prepare_file(IMAGE_FILEPATH)
         history = Chat()
         history.add_user_message(prompt, images=[file_handle])
         vlm = await client.llm.model(model_id)
@@ -121,7 +121,7 @@ async def test_non_vlm_predict_image_param_async(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     model_id = "hugging-quants/llama-3.2-1b-instruct"
    async with AsyncClient() as client:
-        file_handle = await client._files._add_temp_file(IMAGE_FILEPATH)
+        file_handle = await client._files.prepare_file(IMAGE_FILEPATH)
         history = Chat()
         history.add_user_message(prompt, images=[file_handle])
         llm = await client.llm.model(model_id)
diff --git a/tests/sync/test_images_sync.py b/tests/sync/test_images_sync.py
index 69298a7..845074d 100644
--- a/tests/sync/test_images_sync.py
+++ b/tests/sync/test_images_sync.py
@@ -30,7 +30,7 @@ def test_upload_from_pathlike_sync(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     with Client() as client:
         session = client._files
-        file = session._add_temp_file(IMAGE_FILEPATH)
+        file = session.prepare_file(IMAGE_FILEPATH)
         assert file
         assert isinstance(file, FileHandle)
         logging.info(f"Uploaded file: {file}")
@@ -42,7 +42,7 @@ def test_upload_from_file_obj_sync(caplog: LogCap) -> None:
     with Client() as client:
         session = client._files
         with open(IMAGE_FILEPATH, "rb") as f:
-            file = session._add_temp_file(f)
+            file = session.prepare_file(f)
         assert file
         assert isinstance(file, FileHandle)
         logging.info(f"Uploaded file: {file}")
@@ -54,7 +54,7 @@ def test_upload_from_bytesio_sync(caplog: LogCap) -> None:
     with Client() as client:
         session = client._files
         with open(IMAGE_FILEPATH, "rb") as f:
-            file = session._add_temp_file(BytesIO(f.read()))
+            file = session.prepare_file(BytesIO(f.read()))
         assert file
         assert isinstance(file, FileHandle)
         logging.info(f"Uploaded file: {file}")
@@ -67,7 +67,7 @@ def test_vlm_predict_sync(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     model_id = EXPECTED_VLM_ID
     with Client() as client:
-        file_handle = client._files._add_temp_file(IMAGE_FILEPATH)
+        file_handle = client._files.prepare_file(IMAGE_FILEPATH)
         history = Chat()
         history.add_user_message((prompt, file_handle))
         vlm = client.llm.model(model_id)
@@ -86,7 +86,7 @@ def test_non_vlm_predict_sync(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     model_id = "hugging-quants/llama-3.2-1b-instruct"
     with Client() as client:
-        file_handle = client._files._add_temp_file(IMAGE_FILEPATH)
+        file_handle = client._files.prepare_file(IMAGE_FILEPATH)
         history = Chat()
         history.add_user_message((prompt, file_handle))
         llm = client.llm.model(model_id)
@@ -102,7 +102,7 @@ def test_vlm_predict_image_param_sync(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     model_id = EXPECTED_VLM_ID
     with Client() as client:
-        file_handle = client._files._add_temp_file(IMAGE_FILEPATH)
+        file_handle = client._files.prepare_file(IMAGE_FILEPATH)
         history = Chat()
         history.add_user_message(prompt, images=[file_handle])
         vlm = client.llm.model(model_id)
@@ -121,7 +121,7 @@ def test_non_vlm_predict_image_param_sync(caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
     model_id = "hugging-quants/llama-3.2-1b-instruct"
     with Client() as client:
-        file_handle = client._files._add_temp_file(IMAGE_FILEPATH)
+        file_handle = client._files.prepare_file(IMAGE_FILEPATH)
         history = Chat()
         history.add_user_message(prompt, images=[file_handle])
         llm = client.llm.model(model_id)
diff --git a/tests/test_convenience_api.py b/tests/test_convenience_api.py
index e943bf1..77b3376 100644
--- a/tests/test_convenience_api.py
+++ b/tests/test_convenience_api.py
@@ -47,11 +47,11 @@ def test_embedding_specific() -> None:
 
 
 @pytest.mark.lmstudio
-def test_add_temp_file() -> None:
+def test_prepare_file() -> None:
     # API is private until LM Studio file handle support stabilizes
     name = "example-file"
     raw_data = b"raw data"
-    file_handle = lms.sync_api._add_temp_file(raw_data, name)
+    file_handle = lms.sync_api.prepare_file(raw_data, name)
     assert file_handle.name == name
     assert file_handle.size_bytes == len(raw_data)
 
diff --git a/tests/test_history.py b/tests/test_history.py
index bb98c2f..c73074c 100644
--- a/tests/test_history.py
+++ b/tests/test_history.py
@@ -15,7 +15,7 @@
     AnyChatMessageDict,
     ChatHistoryData,
     ChatHistoryDataDict,
-    _FileCacheInputType,
+    LocalFileInput,
     FileHandle,
     _FileHandleCache,
     FileHandleDict,
@@ -492,7 +492,7 @@ def _make_local_file_cache() -> tuple[_FileHandleCache, list[FileHandle], int]:
     # * files with different names are looked up under both names
     cache = _FileHandleCache()
     num_unique_files = 3
-    files_to_cache: list[tuple[_FileCacheInputType, str | None]] = [
+    files_to_cache: list[tuple[LocalFileInput, str | None]] = [
         (b"raw binary data", "raw-binary.txt"),
         (b"raw binary data", "raw-binary.txt"),
         (IMAGE_FILEPATH, None),
@@ -589,13 +589,8 @@ def test_invalid_local_file() -> None:
                 "text": "What do you make of this?",
                 "type": "text",
             },
-            {
-                "fileType": "image",
-                "identifier": "some-image",
-                "name": "lemmy.png",
-                "sizeBytes": 41812,
-                "type": "file",
-            },
+            # Implementation attaches the prepared file handles
+            # before it attaches the prepared image handles
             {
                 "fileType": "text/plain",
                 "identifier": "some-file",
@@ -603,6 +598,13 @@
                 "sizeBytes": 100,
                 "type": "file",
             },
+            {
+                "fileType": "image",
+                "identifier": "some-image",
+                "name": "lemmy.png",
+                "sizeBytes": 41812,
+                "type": "file",
+            },
         ],
         "role": "user",
     },
@@ -621,7 +623,7 @@ def test_user_message_attachments() -> None:
     chat.add_user_message(
         "What do you make of this?",
         images=[INPUT_IMAGE_HANDLE],
-        _files=[INPUT_FILE_HANDLE],
+        files=[INPUT_FILE_HANDLE],
    )
     history = chat._get_history()
     assert history["messages"] == EXPECTED_USER_ATTACHMENT_MESSAGES
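
For reference, a minimal usage sketch of how the renamed public API reads after this change. The image path and model identifier are illustrative placeholders, a running LM Studio instance is assumed, and the `respond()` prediction call is the SDK's existing entry point rather than something introduced by this patch:

    from lmstudio import Chat, Client

    with Client() as client:
        # Upload the local file (or reuse the cached handle); accepts a path,
        # raw bytes, or a binary file object, per the LocalFileInput alias.
        file_handle = client.prepare_file("lemmy.png")  # illustrative path

        chat = Chat()
        # Both keyword-only parameters are now public; handles passed via
        # `files=` are attached before those passed via `images=`.
        chat.add_user_message("What do you make of this?", images=[file_handle])

        vlm = client.llm.model("qwen2-vl-2b-instruct")  # illustrative VLM id
        print(vlm.respond(chat))

The module-level convenience function is also available as `lms.sync_api.prepare_file(...)`, as exercised by test_prepare_file above.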