diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py
index a9912c4..462e48c 100644
--- a/src/lmstudio/async_api.py
+++ b/src/lmstudio/async_api.py
@@ -96,6 +96,7 @@
 )
 from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
 from ._sdk_models import (
+    EmbeddingRpcCountTokensParameter,
     EmbeddingRpcEmbedStringParameter,
     EmbeddingRpcTokenizeParameter,
     LlmApplyPromptTemplateOpts,
@@ -733,6 +734,18 @@ async def _get_context_length(self, model_specifier: AnyModelSpecifier) -> int:
         raw_model_info = await self._get_api_model_info(model_specifier)
         return int(raw_model_info.get("contextLength", -1))
 
+    async def _count_tokens(
+        self, model_specifier: AnyModelSpecifier, input: str
+    ) -> int:
+        params = EmbeddingRpcCountTokensParameter._from_api_dict(
+            {
+                "specifier": _model_spec_to_api_dict(model_specifier),
+                "inputString": input,
+            }
+        )
+        response = await self.remote_call("countTokens", params)
+        return int(response["tokenCount"])
+
     # Private helper method to allow the main API to easily accept iterables
     async def _tokenize_text(
         self, model_specifier: AnyModelSpecifier, input: str
@@ -748,7 +761,6 @@ async def _tokenize_text(
 
     # Alas, type hints don't properly support distinguishing str vs Iterable[str]:
     # https://github.com/python/typing/issues/256
-    @sdk_public_api_async()
     async def _tokenize(
         self, model_specifier: AnyModelSpecifier, input: str | Iterable[str]
     ) -> Sequence[int] | Sequence[Sequence[int]]:
@@ -1191,6 +1203,11 @@ async def tokenize(
         """Tokenize the input string(s) using this model."""
         return await self._session._tokenize(self.identifier, input)
 
+    @sdk_public_api_async()
+    async def count_tokens(self, input: str) -> int:
+        """Report the number of tokens needed for the input string using this model."""
+        return await self._session._count_tokens(self.identifier, input)
+
     @sdk_public_api_async()
     async def get_context_length(self) -> int:
         """Get the context length of this model."""
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index ffa196e..5ce86a4 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -124,6 +124,7 @@
 )
 from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
 from ._sdk_models import (
+    EmbeddingRpcCountTokensParameter,
     EmbeddingRpcEmbedStringParameter,
     EmbeddingRpcTokenizeParameter,
     LlmApplyPromptTemplateOpts,
@@ -902,6 +903,16 @@ def _get_context_length(self, model_specifier: AnyModelSpecifier) -> int:
         raw_model_info = self._get_api_model_info(model_specifier)
         return int(raw_model_info.get("contextLength", -1))
 
+    def _count_tokens(self, model_specifier: AnyModelSpecifier, input: str) -> int:
+        params = EmbeddingRpcCountTokensParameter._from_api_dict(
+            {
+                "specifier": _model_spec_to_api_dict(model_specifier),
+                "inputString": input,
+            }
+        )
+        response = self.remote_call("countTokens", params)
+        return int(response["tokenCount"])
+
     # Private helper method to allow the main API to easily accept iterables
     def _tokenize_text(
         self, model_specifier: AnyModelSpecifier, input: str
@@ -1353,6 +1364,11 @@ def tokenize(
         """Tokenize the input string(s) using this model."""
         return self._session._tokenize(self.identifier, input)
 
+    @sdk_public_api()
+    def count_tokens(self, input: str) -> int:
+        """Report the number of tokens needed for the input string using this model."""
+        return self._session._count_tokens(self.identifier, input)
+
     @sdk_public_api()
     def get_context_length(self) -> int:
         """Get the context length of this model."""
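For reference, a minimal sketch of how the new public method is meant to be called from the synchronous client. The model identifier is a placeholder, not something shipped with the SDK; everything else is the API surface exercised by the updated tests below.

```python
from lmstudio import Client

with Client() as client:
    # "my-local-llm" is a placeholder for whatever model is loaded locally.
    model = client.llm.model("my-local-llm")
    text = "Hello, world!"
    # New in this change: report the token count without materializing the tokens.
    num_tokens = model.count_tokens(text)
    tokens = model.tokenize(text)
    # For LLMs the two views agree, as the updated tokenize tests assert.
    assert num_tokens == len(tokens)
```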
diff --git a/tests/async/test_embedding_async.py b/tests/async/test_embedding_async.py
index 482f217..53e52cf 100644
--- a/tests/async/test_embedding_async.py
+++ b/tests/async/test_embedding_async.py
@@ -63,11 +63,15 @@ async def test_tokenize_async(model_id: str, caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
 
     async with AsyncClient() as client:
-        session = client.embedding
-        response = await session._tokenize(model_id, input=text)
+        model = await client.embedding.model(model_id)
+        num_tokens = await model.count_tokens(text)
+        response = await model.tokenize(text)
         logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
+    # Ensure token count and tokenization are consistent
+    # (embedding models add extra start/end markers during actual tokenization)
+    assert len(response) == num_tokens + 2
 
     # the response should be deterministic if we set constant seed
     # so we can also check the value if desired
@@ -80,8 +84,8 @@ async def test_tokenize_list_async(model_id: str, caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
 
     async with AsyncClient() as client:
-        session = client.embedding
-        response = await session._tokenize(model_id, input=text)
+        model = await client.embedding.model(model_id)
+        response = await model.tokenize(text)
         logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
@@ -142,6 +146,10 @@ async def test_invalid_model_request_async(caplog: LogCap) -> None:
             with pytest.raises(LMStudioModelNotFoundError) as exc_info:
                 await model.embed("Some text")
             check_sdk_error(exc_info, __file__)
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.count_tokens("Some text")
+            check_sdk_error(exc_info, __file__)
         with anyio.fail_after(30):
             with pytest.raises(LMStudioModelNotFoundError) as exc_info:
                 await model.tokenize("Some text")
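The new embedding test comments note that tokenize() adds start/end markers which count_tokens() does not include. A rough async illustration of that relationship, assuming a placeholder embedding model identifier; the exact offset of 2 is what the test above expects for its reference embedding model.

```python
import asyncio

from lmstudio import AsyncClient


async def main() -> None:
    async with AsyncClient() as client:
        # Placeholder identifier; substitute a locally available embedding model.
        model = await client.embedding.model("my-embedding-model")
        text = "Hello, world!"
        num_tokens = await model.count_tokens(text)
        tokens = await model.tokenize(text)
        # The difference is the extra start/end markers added during tokenization.
        print(len(tokens) - num_tokens)  # 2 for the test suite's reference model


asyncio.run(main())
```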
diff --git a/tests/async/test_llm_async.py b/tests/async/test_llm_async.py
index bfc2d98..6e94a0f 100644
--- a/tests/async/test_llm_async.py
+++ b/tests/async/test_llm_async.py
@@ -2,12 +2,18 @@
 
 import logging
 
+import anyio
 import pytest
 from pytest import LogCaptureFixture as LogCap
 
-from lmstudio import AsyncClient, LlmLoadModelConfig, history
+from lmstudio import (
+    AsyncClient,
+    LlmLoadModelConfig,
+    LMStudioModelNotFoundError,
+    history,
+)
 
-from ..support import EXPECTED_LLM, EXPECTED_LLM_ID
+from ..support import EXPECTED_LLM, EXPECTED_LLM_ID, check_sdk_error
 
 
 @pytest.mark.asyncio
@@ -52,10 +58,14 @@ async def test_tokenize_async(model_id: str, caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
 
     async with AsyncClient() as client:
-        response = await client.llm._tokenize(model_id, input=text)
+        model = await client.llm.model(model_id)
+        num_tokens = await model.count_tokens(text)
+        response = await model.tokenize(text)
         logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
+    # Ensure token count and tokenization are consistent
+    assert len(response) == num_tokens
 
 
 @pytest.mark.asyncio
@@ -66,7 +76,8 @@ async def test_tokenize_list_async(model_id: str, caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
 
     async with AsyncClient() as client:
-        response = await client.llm._tokenize(model_id, input=text)
+        model = await client.llm.model(model_id)
+        response = await model.tokenize(text)
         logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
@@ -109,3 +120,33 @@ async def test_get_model_info_async(model_id: str, caplog: LogCap) -> None:
         response = await client.llm.get_model_info(model_id)
     logging.info(f"Model config response: {response}")
     assert response
+
+
+@pytest.mark.asyncio
+@pytest.mark.lmstudio
+async def test_invalid_model_request_async(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+    async with AsyncClient() as client:
+        # Deliberately create an invalid model handle
+        model = client.llm._create_handle("No such model")
+        # This should error rather than timing out,
+        # but avoid any risk of the client hanging...
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.complete("Some text")
+            check_sdk_error(exc_info, __file__)
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.respond("Some text")
+            check_sdk_error(exc_info, __file__)
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.count_tokens("Some text")
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.tokenize("Some text")
+            check_sdk_error(exc_info, __file__)
+        with anyio.fail_after(30):
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                await model.get_context_length()
+            check_sdk_error(exc_info, __file__)
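The new invalid-model tests above check that RPCs issued through a handle for an unavailable model fail promptly with LMStudioModelNotFoundError instead of hanging. A hedged sketch of that behaviour outside pytest, reusing the private _create_handle() helper the tests rely on (the model name is deliberately bogus):

```python
import asyncio

from lmstudio import AsyncClient, LMStudioModelNotFoundError


async def main() -> None:
    async with AsyncClient() as client:
        # Same private helper the tests use to build a handle without validation.
        model = client.llm._create_handle("No such model")
        try:
            await model.count_tokens("Some text")
        except LMStudioModelNotFoundError as exc:
            # Expected: the server reports the missing model rather than timing out.
            print(f"Rejected as expected: {exc}")


asyncio.run(main())
```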
diff --git a/tests/sync/test_embedding_sync.py b/tests/sync/test_embedding_sync.py
index 89a39e6..223dabe 100644
--- a/tests/sync/test_embedding_sync.py
+++ b/tests/sync/test_embedding_sync.py
@@ -67,11 +67,15 @@ def test_tokenize_sync(model_id: str, caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
 
     with Client() as client:
-        session = client.embedding
-        response = session._tokenize(model_id, input=text)
+        model = client.embedding.model(model_id)
+        num_tokens = model.count_tokens(text)
+        response = model.tokenize(text)
         logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
+    # Ensure token count and tokenization are consistent
+    # (embedding models add extra start/end markers during actual tokenization)
+    assert len(response) == num_tokens + 2
 
     # the response should be deterministic if we set constant seed
     # so we can also check the value if desired
@@ -83,8 +87,8 @@ def test_tokenize_list_sync(model_id: str, caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
 
     with Client() as client:
-        session = client.embedding
-        response = session._tokenize(model_id, input=text)
+        model = client.embedding.model(model_id)
+        response = model.tokenize(text)
         logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
@@ -141,6 +145,10 @@ def test_invalid_model_request_sync(caplog: LogCap) -> None:
             with pytest.raises(LMStudioModelNotFoundError) as exc_info:
                 model.embed("Some text")
             check_sdk_error(exc_info, __file__)
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.count_tokens("Some text")
+            check_sdk_error(exc_info, __file__)
         with nullcontext():
             with pytest.raises(LMStudioModelNotFoundError) as exc_info:
                 model.tokenize("Some text")
diff --git a/tests/sync/test_llm_sync.py b/tests/sync/test_llm_sync.py
index 9945362..3b99cab 100644
--- a/tests/sync/test_llm_sync.py
+++ b/tests/sync/test_llm_sync.py
@@ -8,13 +8,19 @@
 """Test non-inference methods on LLMs."""
 
 import logging
+from contextlib import nullcontext
 
 import pytest
 from pytest import LogCaptureFixture as LogCap
 
-from lmstudio import Client, LlmLoadModelConfig, history
+from lmstudio import (
+    Client,
+    LlmLoadModelConfig,
+    LMStudioModelNotFoundError,
+    history,
+)
 
-from ..support import EXPECTED_LLM, EXPECTED_LLM_ID
+from ..support import EXPECTED_LLM, EXPECTED_LLM_ID, check_sdk_error
 
 
 @pytest.mark.lmstudio
@@ -55,10 +61,14 @@ def test_tokenize_sync(model_id: str, caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
 
     with Client() as client:
-        response = client.llm._tokenize(model_id, input=text)
+        model = client.llm.model(model_id)
+        num_tokens = model.count_tokens(text)
+        response = model.tokenize(text)
         logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
+    # Ensure token count and tokenization are consistent
+    assert len(response) == num_tokens
 
 
 @pytest.mark.lmstudio
@@ -68,7 +78,8 @@ def test_tokenize_list_sync(model_id: str, caplog: LogCap) -> None:
     caplog.set_level(logging.DEBUG)
 
     with Client() as client:
-        response = client.llm._tokenize(model_id, input=text)
+        model = client.llm.model(model_id)
+        response = model.tokenize(text)
         logging.info(f"Tokenization response: {response}")
     assert response
     assert isinstance(response, list)
@@ -108,3 +119,32 @@ def test_get_model_info_sync(model_id: str, caplog: LogCap) -> None:
         response = client.llm.get_model_info(model_id)
     logging.info(f"Model config response: {response}")
     assert response
+
+
+@pytest.mark.lmstudio
+def test_invalid_model_request_sync(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+    with Client() as client:
+        # Deliberately create an invalid model handle
+        model = client.llm._create_handle("No such model")
+        # This should error rather than timing out,
+        # but avoid any risk of the client hanging...
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.complete("Some text")
+            check_sdk_error(exc_info, __file__)
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.respond("Some text")
+            check_sdk_error(exc_info, __file__)
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.count_tokens("Some text")
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.tokenize("Some text")
+            check_sdk_error(exc_info, __file__)
+        with nullcontext():
+            with pytest.raises(LMStudioModelNotFoundError) as exc_info:
+                model.get_context_length()
+            check_sdk_error(exc_info, __file__)
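One practical use of the new method, sketched against the same public surface the tests exercise (count_tokens() together with the existing get_context_length()). The model identifier and headroom margin are illustrative assumptions, not values from this change.

```python
from lmstudio import Client


def prompt_fits(prompt: str, model_id: str = "my-local-llm", reserve: int = 256) -> bool:
    """Check whether a prompt leaves `reserve` tokens of headroom in the context window."""
    with Client() as client:
        model = client.llm.model(model_id)
        needed = model.count_tokens(prompt)
        limit = model.get_context_length()
        return needed + reserve <= limit


if __name__ == "__main__":
    print(prompt_fits("Hello, world!"))
```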