19 changes: 18 additions & 1 deletion src/lmstudio/async_api.py
@@ -96,6 +96,7 @@
)
from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
from ._sdk_models import (
EmbeddingRpcCountTokensParameter,
EmbeddingRpcEmbedStringParameter,
EmbeddingRpcTokenizeParameter,
LlmApplyPromptTemplateOpts,
@@ -733,6 +734,18 @@ async def _get_context_length(self, model_specifier: AnyModelSpecifier) -> int:
raw_model_info = await self._get_api_model_info(model_specifier)
return int(raw_model_info.get("contextLength", -1))

async def _count_tokens(
self, model_specifier: AnyModelSpecifier, input: str
) -> int:
params = EmbeddingRpcCountTokensParameter._from_api_dict(
{
"specifier": _model_spec_to_api_dict(model_specifier),
"inputString": input,
}
)
response = await self.remote_call("countTokens", params)
return int(response["tokenCount"])

# Private helper method to allow the main API to easily accept iterables
async def _tokenize_text(
self, model_specifier: AnyModelSpecifier, input: str
@@ -748,7 +761,6 @@ async def _tokenize_text(

# Alas, type hints don't properly support distinguishing str vs Iterable[str]:
# https://github.com/python/typing/issues/256
@sdk_public_api_async()
async def _tokenize(
self, model_specifier: AnyModelSpecifier, input: str | Iterable[str]
) -> Sequence[int] | Sequence[Sequence[int]]:
@@ -1191,6 +1203,11 @@ async def tokenize(
"""Tokenize the input string(s) using this model."""
return await self._session._tokenize(self.identifier, input)

@sdk_public_api_async()
async def count_tokens(self, input: str) -> int:
"""Report the number of tokens needed for the input string using this model."""
return await self._session._count_tokens(self.identifier, input)

@sdk_public_api_async()
async def get_context_length(self) -> int:
"""Get the context length of this model."""
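For context (not part of the diff): a minimal sketch of how the new asynchronous `count_tokens()` method might be used. The model identifier below is a placeholder, and any LLM or embedding model handle obtained via `client.llm.model()` or `client.embedding.model()` is assumed to expose the same method.

```python
import asyncio

from lmstudio import AsyncClient


async def main() -> None:
    async with AsyncClient() as client:
        # Placeholder identifier: substitute a model that is actually
        # available in the local LM Studio instance.
        model = await client.llm.model("my-model")
        # New in this PR: report the token count for a prompt without
        # materialising the full token sequence.
        num_tokens = await model.count_tokens("Hello, world!")
        print(f"Prompt requires {num_tokens} tokens")


asyncio.run(main())
```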
16 changes: 16 additions & 0 deletions src/lmstudio/sync_api.py
@@ -124,6 +124,7 @@
)
from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
from ._sdk_models import (
EmbeddingRpcCountTokensParameter,
EmbeddingRpcEmbedStringParameter,
EmbeddingRpcTokenizeParameter,
LlmApplyPromptTemplateOpts,
@@ -902,6 +903,16 @@ def _get_context_length(self, model_specifier: AnyModelSpecifier) -> int:
raw_model_info = self._get_api_model_info(model_specifier)
return int(raw_model_info.get("contextLength", -1))

def _count_tokens(self, model_specifier: AnyModelSpecifier, input: str) -> int:
params = EmbeddingRpcCountTokensParameter._from_api_dict(
{
"specifier": _model_spec_to_api_dict(model_specifier),
"inputString": input,
}
)
response = self.remote_call("countTokens", params)
return int(response["tokenCount"])

# Private helper method to allow the main API to easily accept iterables
def _tokenize_text(
self, model_specifier: AnyModelSpecifier, input: str
@@ -1353,6 +1364,11 @@ def tokenize(
"""Tokenize the input string(s) using this model."""
return self._session._tokenize(self.identifier, input)

@sdk_public_api()
def count_tokens(self, input: str) -> int:
"""Report the number of tokens needed for the input string using this model."""
return self._session._count_tokens(self.identifier, input)

@sdk_public_api()
def get_context_length(self) -> int:
"""Get the context length of this model."""
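The synchronous API mirrors the async one. A minimal sketch, again with a placeholder model identifier; the relationship asserted here between `count_tokens()` and `tokenize()` is the one the test changes below rely on for LLMs (embedding models add start/end markers, so their token sequences are two entries longer than the reported count):

```python
from lmstudio import Client

with Client() as client:
    # Placeholder identifier: substitute a locally available LLM.
    model = client.llm.model("my-model")
    num_tokens = model.count_tokens("Hello, world!")
    tokens = model.tokenize("Hello, world!")
    # For LLM handles the reported count matches the tokenized length.
    assert len(tokens) == num_tokens
    print(f"Prompt requires {num_tokens} tokens")
```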
16 changes: 12 additions & 4 deletions tests/async/test_embedding_async.py
@@ -63,11 +63,15 @@ async def test_tokenize_async(model_id: str, caplog: LogCap) -> None:

caplog.set_level(logging.DEBUG)
async with AsyncClient() as client:
session = client.embedding
response = await session._tokenize(model_id, input=text)
model = await client.embedding.model(model_id)
num_tokens = await model.count_tokens(text)
response = await model.tokenize(text)
logging.info(f"Tokenization response: {response}")
assert response
assert isinstance(response, list)
# Ensure token count and tokenization are consistent
# (embedding models add extra start/end markers during actual tokenization)
assert len(response) == num_tokens + 2
# the response should be deterministic if we set constant seed
# so we can also check the value if desired

@@ -80,8 +84,8 @@ async def test_tokenize_list_async(model_id: str, caplog: LogCap) -> None:

caplog.set_level(logging.DEBUG)
async with AsyncClient() as client:
session = client.embedding
response = await session._tokenize(model_id, input=text)
model = await client.embedding.model(model_id)
response = await model.tokenize(text)
logging.info(f"Tokenization response: {response}")
assert response
assert isinstance(response, list)
@@ -142,6 +146,10 @@ async def test_invalid_model_request_async(caplog: LogCap) -> None:
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
await model.embed("Some text")
check_sdk_error(exc_info, __file__)
with anyio.fail_after(30):
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
await model.count_tokens("Some text")
check_sdk_error(exc_info, __file__)
with anyio.fail_after(30):
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
await model.tokenize("Some text")
49 changes: 45 additions & 4 deletions tests/async/test_llm_async.py
@@ -2,12 +2,18 @@

import logging

import anyio
import pytest
from pytest import LogCaptureFixture as LogCap

from lmstudio import AsyncClient, LlmLoadModelConfig, history
from lmstudio import (
AsyncClient,
LlmLoadModelConfig,
LMStudioModelNotFoundError,
history,
)

from ..support import EXPECTED_LLM, EXPECTED_LLM_ID
from ..support import EXPECTED_LLM, EXPECTED_LLM_ID, check_sdk_error


@pytest.mark.asyncio
@@ -52,10 +58,14 @@ async def test_tokenize_async(model_id: str, caplog: LogCap) -> None:

caplog.set_level(logging.DEBUG)
async with AsyncClient() as client:
response = await client.llm._tokenize(model_id, input=text)
model = await client.llm.model(model_id)
num_tokens = await model.count_tokens(text)
response = await model.tokenize(text)
logging.info(f"Tokenization response: {response}")
assert response
assert isinstance(response, list)
# Ensure token count and tokenization are consistent
assert len(response) == num_tokens


@pytest.mark.asyncio
@@ -66,7 +76,8 @@ async def test_tokenize_list_async(model_id: str, caplog: LogCap) -> None:

caplog.set_level(logging.DEBUG)
async with AsyncClient() as client:
response = await client.llm._tokenize(model_id, input=text)
model = await client.llm.model(model_id)
response = await model.tokenize(text)
logging.info(f"Tokenization response: {response}")
assert response
assert isinstance(response, list)
@@ -109,3 +120,33 @@ async def test_get_model_info_async(model_id: str, caplog: LogCap) -> None:
response = await client.llm.get_model_info(model_id)
logging.info(f"Model config response: {response}")
assert response


@pytest.mark.asyncio
@pytest.mark.lmstudio
async def test_invalid_model_request_async(caplog: LogCap) -> None:
caplog.set_level(logging.DEBUG)
async with AsyncClient() as client:
# Deliberately create an invalid model handle
model = client.llm._create_handle("No such model")
# This should error rather than timing out,
# but avoid any risk of the client hanging...
with anyio.fail_after(30):
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
await model.complete("Some text")
check_sdk_error(exc_info, __file__)
with anyio.fail_after(30):
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
await model.respond("Some text")
check_sdk_error(exc_info, __file__)
with anyio.fail_after(30):
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
await model.count_tokens("Some text")
with anyio.fail_after(30):
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
await model.tokenize("Some text")
check_sdk_error(exc_info, __file__)
with anyio.fail_after(30):
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
await model.get_context_length()
check_sdk_error(exc_info, __file__)
16 changes: 12 additions & 4 deletions tests/sync/test_embedding_sync.py
@@ -67,11 +67,15 @@ def test_tokenize_sync(model_id: str, caplog: LogCap) -> None:

caplog.set_level(logging.DEBUG)
with Client() as client:
session = client.embedding
response = session._tokenize(model_id, input=text)
model = client.embedding.model(model_id)
num_tokens = model.count_tokens(text)
response = model.tokenize(text)
logging.info(f"Tokenization response: {response}")
assert response
assert isinstance(response, list)
# Ensure token count and tokenization are consistent
# (embedding models add extra start/end markers during actual tokenization)
assert len(response) == num_tokens + 2
# the response should be deterministic if we set constant seed
# so we can also check the value if desired

@@ -83,8 +87,8 @@ def test_tokenize_list_sync(model_id: str, caplog: LogCap) -> None:

caplog.set_level(logging.DEBUG)
with Client() as client:
session = client.embedding
response = session._tokenize(model_id, input=text)
model = client.embedding.model(model_id)
response = model.tokenize(text)
logging.info(f"Tokenization response: {response}")
assert response
assert isinstance(response, list)
@@ -141,6 +145,10 @@ def test_invalid_model_request_sync(caplog: LogCap) -> None:
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
model.embed("Some text")
check_sdk_error(exc_info, __file__)
with nullcontext():
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
model.count_tokens("Some text")
check_sdk_error(exc_info, __file__)
with nullcontext():
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
model.tokenize("Some text")
48 changes: 44 additions & 4 deletions tests/sync/test_llm_sync.py
@@ -8,13 +8,19 @@
"""Test non-inference methods on LLMs."""

import logging
from contextlib import nullcontext

import pytest
from pytest import LogCaptureFixture as LogCap

from lmstudio import Client, LlmLoadModelConfig, history
from lmstudio import (
Client,
LlmLoadModelConfig,
LMStudioModelNotFoundError,
history,
)

from ..support import EXPECTED_LLM, EXPECTED_LLM_ID
from ..support import EXPECTED_LLM, EXPECTED_LLM_ID, check_sdk_error


@pytest.mark.lmstudio
@@ -55,10 +61,14 @@ def test_tokenize_sync(model_id: str, caplog: LogCap) -> None:

caplog.set_level(logging.DEBUG)
with Client() as client:
response = client.llm._tokenize(model_id, input=text)
model = client.llm.model(model_id)
num_tokens = model.count_tokens(text)
response = model.tokenize(text)
logging.info(f"Tokenization response: {response}")
assert response
assert isinstance(response, list)
# Ensure token count and tokenization are consistent
assert len(response) == num_tokens


@pytest.mark.lmstudio
@@ -68,7 +78,8 @@ def test_tokenize_list_sync(model_id: str, caplog: LogCap) -> None:

caplog.set_level(logging.DEBUG)
with Client() as client:
response = client.llm._tokenize(model_id, input=text)
model = client.llm.model(model_id)
response = model.tokenize(text)
logging.info(f"Tokenization response: {response}")
assert response
assert isinstance(response, list)
@@ -108,3 +119,32 @@ def test_get_model_info_sync(model_id: str, caplog: LogCap) -> None:
response = client.llm.get_model_info(model_id)
logging.info(f"Model config response: {response}")
assert response


@pytest.mark.lmstudio
def test_invalid_model_request_sync(caplog: LogCap) -> None:
caplog.set_level(logging.DEBUG)
with Client() as client:
# Deliberately create an invalid model handle
model = client.llm._create_handle("No such model")
# This should error rather than timing out,
# but avoid any risk of the client hanging...
with nullcontext():
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
model.complete("Some text")
check_sdk_error(exc_info, __file__)
with nullcontext():
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
model.respond("Some text")
check_sdk_error(exc_info, __file__)
with nullcontext():
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
model.count_tokens("Some text")
with nullcontext():
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
model.tokenize("Some text")
check_sdk_error(exc_info, __file__)
with nullcontext():
with pytest.raises(LMStudioModelNotFoundError) as exc_info:
model.get_context_length()
check_sdk_error(exc_info, __file__)