diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py
index ee947d6..3bcd39f 100644
--- a/src/lmstudio/async_api.py
+++ b/src/lmstudio/async_api.py
@@ -1,6 +1,8 @@
 """Async I/O protocol implementation for the LM Studio remote access API."""
 
 import asyncio
+import itertools
+import time
 
 from abc import abstractmethod
 from contextlib import AsyncExitStack, asynccontextmanager
@@ -28,16 +30,27 @@
 
 from httpx_ws import AsyncWebSocketSession
 
-from .sdk_api import LMStudioRuntimeError, sdk_public_api, sdk_public_api_async
+from .sdk_api import (
+    LMStudioRuntimeError,
+    LMStudioValueError,
+    sdk_callback_invocation,
+    sdk_public_api,
+    sdk_public_api_async,
+)
 from .schemas import AnyLMStudioStruct, DictObject
 from .history import (
+    AssistantResponse,
     Chat,
     ChatHistoryDataDict,
     FileHandle,
     LocalFileInput,
     _LocalFileData,
+    ToolCallResultData,
+    ToolCallRequest,
+    ToolResultMessage,
 )
 from .json_api import (
+    ActResult,
     AnyLoadConfig,
     AnyModelSpecifier,
     AvailableModelBase,
@@ -63,6 +76,7 @@
     LlmPredictionFragment,
     LMStudioCancelledError,
     LMStudioClientError,
+    LMStudioPredictionError,
     LMStudioWebsocket,
     LMStudioWebsocketError,
     LoadModelEndpoint,
@@ -80,11 +94,15 @@
     PredictionFragmentEvent,
     PredictionMessageCallback,
     PredictionResult,
+    PredictionRoundResult,
+    PredictionRxEvent,
+    PredictionToolCallEvent,
     PromptProcessingCallback,
     RemoteCallHandler,
     ResponseSchema,
     SendMessageAsync,
     TModelInfo,
+    ToolDefinition,
     check_model_namespace,
     load_struct,
     _model_spec_to_api_dict,
@@ -895,14 +913,18 @@ async def __aexit__(
         await self.aclose()
 
     async def __aiter__(self) -> AsyncIterator[LlmPredictionFragment]:
+        async for event in self._iter_events():
+            if isinstance(event, PredictionFragmentEvent):
+                yield event.arg
+
+    async def _iter_events(self) -> AsyncIterator[PredictionRxEvent]:
        endpoint = self._endpoint
        async with self:
            assert self._channel is not None
            async for contents in self._channel.rx_stream():
                for event in endpoint.iter_message_events(contents):
                    endpoint.handle_rx_event(event)
-                    if isinstance(event, PredictionFragmentEvent):
-                        yield event.arg
+                    yield event
                if endpoint.is_finished:
                    break
            self._mark_finished()
@@ -1244,6 +1266,199 @@ async def respond(
                 pass
         return prediction_stream.result()
 
+    # TODO: Improve code sharing between sync and async multi-round predictions
+    # TODO: Accept async tools in the tools iterable
+    @sdk_public_api_async()
+    async def act(
+        self,
+        chat: Chat | ChatHistoryDataDict | str,
+        tools: Iterable[ToolDefinition],
+        *,
+        max_prediction_rounds: int | None = None,
+        max_parallel_tool_calls: int | None = 1,
+        config: LlmPredictionConfig | LlmPredictionConfigDict | None = None,
+        preset: str | None = None,
+        on_message: Callable[[AssistantResponse | ToolResultMessage], Any]
+        | None = None,
+        on_first_token: Callable[[int], Any] | None = None,
+        on_prediction_fragment: Callable[[LlmPredictionFragment, int], Any]
+        | None = None,
+        on_round_start: Callable[[int], Any] | None = None,
+        on_round_end: Callable[[int], Any] | None = None,
+        on_prediction_completed: Callable[[PredictionRoundResult], Any] | None = None,
+        on_prompt_processing_progress: Callable[[float, int], Any] | None = None,
+        handle_invalid_tool_request: Callable[
+            [LMStudioPredictionError, ToolCallRequest | None], str | None
+        ]
+        | None = None,
+    ) -> ActResult:
+        """Request a response (with implicit tool use) in an ongoing agent chat session.
+
+        Note: details of configuration fields may change in SDK feature releases.
+        """
+        start_time = time.perf_counter()
+        # It is not yet possible to combine tool calling with requests for structured responses
+        response_format = None
+        agent_chat: Chat = Chat.from_history(chat)
+        del chat  # Avoid any further access to the input chat history
+        # Multiple rounds, until all tool calls are resolved or limit is reached
+        round_counter: Iterable[int]
+        if max_prediction_rounds is not None:
+            if max_prediction_rounds < 1:
+                raise LMStudioValueError(
+                    f"Max prediction rounds must be at least 1 ({max_prediction_rounds!r} given)"
+                )
+            round_counter = range(max_prediction_rounds)
+            final_round_index = max_prediction_rounds - 1
+        else:
+            # Do not force a final round when no limit is specified
+            final_round_index = -1
+            round_counter = itertools.count()
+        llm_tool_args = ChatResponseEndpoint.parse_tools(tools)
+        del tools
+        # Supply the round index to any endpoint callbacks that expect one
+        round_index: int
+        on_first_token_for_endpoint: PredictionFirstTokenCallback | None = None
+        if on_first_token is not None:
+
+            def _wrapped_on_first_token() -> None:
+                assert on_first_token is not None
+                on_first_token(round_index)
+
+            on_first_token_for_endpoint = _wrapped_on_first_token
+        on_prediction_fragment_for_endpoint: PredictionFragmentCallback | None = None
+        if on_prediction_fragment is not None:
+
+            def _wrapped_on_prediction_fragment(
+                fragment: LlmPredictionFragment,
+            ) -> None:
+                assert on_prediction_fragment is not None
+                on_prediction_fragment(fragment, round_index)
+
+            on_prediction_fragment_for_endpoint = _wrapped_on_prediction_fragment
+        on_prompt_processing_for_endpoint: PromptProcessingCallback | None = None
+        if on_prompt_processing_progress is not None:
+
+            def _wrapped_on_prompt_processing_progress(progress: float) -> None:
+                assert on_prompt_processing_progress is not None
+                on_prompt_processing_progress(progress, round_index)
+
+            on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
+        # TODO: Implementation to this point is common between the sync and async APIs
+        # Implementation past this point differs (as the sync API uses its own thread pool)
+
+        # Request predictions until no more tool call requests are received in response
+        # (or the maximum number of prediction rounds is reached)
+        for round_index in round_counter:
+            self._logger.debug(
+                "Starting .act() prediction round", round_index=round_index
+            )
+            if on_round_start is not None:
+                err_msg = f"Round start callback failed for {self!r}"
+                with sdk_callback_invocation(err_msg, self._logger):
+                    on_round_start(round_index)
+            # Update the endpoint definition on each iteration in order to:
+            # * update the chat history with the previous round result
+            # * be able to disallow tool use when the rounds are limited
+            # TODO: Refactor endpoint API to avoid repeatedly performing the
+            # LlmPredictionConfig -> KvConfigStack transformation
+            endpoint = ChatResponseEndpoint(
+                self.identifier,
+                agent_chat,
+                response_format,
+                config,
+                preset,
+                None,  # on_message is invoked directly
+                on_first_token_for_endpoint,
+                on_prediction_fragment_for_endpoint,
+                on_prompt_processing_for_endpoint,
+                handle_invalid_tool_request,
+                *(llm_tool_args if round_index != final_round_index else (None, None)),
+            )
+            channel_cm = self._session._create_channel(endpoint)
+            prediction_stream = AsyncPredictionStream(channel_cm, endpoint)
+            tool_call_requests: list[ToolCallRequest] = []
+            parsed_tool_calls: list[Callable[[], ToolCallResultData]] = []
+            async for event in prediction_stream._iter_events():
+                if isinstance(event, PredictionToolCallEvent):
+                    tool_call_request = event.arg
+                    tool_call_requests.append(tool_call_request)
+                    # TODO: Also handle async tool calls here
+                    tool_call = endpoint.request_tool_call(tool_call_request)
+                    parsed_tool_calls.append(tool_call)
+            prediction = prediction_stream.result()
+            self._logger.debug(
+                "Completed .act() prediction round", round_index=round_index
+            )
+            if on_prediction_completed:
+                round_result = PredictionRoundResult.from_result(
+                    prediction, round_index
+                )
+                err_msg = f"Prediction completed callback failed for {self!r}"
+                with sdk_callback_invocation(err_msg, self._logger):
+                    on_prediction_completed(round_result)
+            if parsed_tool_calls:
+                if max_parallel_tool_calls is None:
+                    max_parallel_tool_calls = len(parsed_tool_calls)
+                active_tool_calls = 0
+                tool_call_futures: list[asyncio.Future[ToolCallResultData]] = []
+                for tool_call in parsed_tool_calls:
+                    if active_tool_calls >= max_parallel_tool_calls:
+                        # Wait for a previous call to finish before starting another one
+                        _done, pending = await asyncio.wait(
+                            tool_call_futures, return_when=asyncio.FIRST_COMPLETED
+                        )
+                        active_tool_calls = len(pending)
+                    # TODO: Also handle async tool calls here
+                    tool_call_futures.append(
+                        asyncio.ensure_future(asyncio.to_thread(tool_call))
+                    )
+                    active_tool_calls += 1
+                tool_call_results: list[ToolCallResultData] = []
+                for tool_call_request, tool_call_future in zip(
+                    tool_call_requests, tool_call_futures
+                ):
+                    try:
+                        await tool_call_future
+                    except Exception as exc:
+                        tool_call_result = endpoint._handle_failed_tool_request(
+                            exc, tool_call_request
+                        )
+                    else:
+                        tool_call_result = tool_call_future.result()
+                    tool_call_results.append(tool_call_result)
+                requests_message = agent_chat.add_assistant_response(
+                    prediction, tool_call_requests
+                )
+                results_message = agent_chat.add_tool_results(tool_call_results)
+                if on_message is not None:
+                    err_msg = f"Tool request message callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        on_message(requests_message)
+                    err_msg = f"Tool result message callback failed for {self!r}"
+                    with sdk_callback_invocation(err_msg, self._logger):
+                        on_message(results_message)
+            elif on_message is not None:
+                err_msg = f"Final response message callback failed for {self!r}"
+                with sdk_callback_invocation(err_msg, self._logger):
+                    on_message(agent_chat.add_assistant_response(prediction))
+            if on_round_end is not None:
+                err_msg = f"Round end callback failed for {self!r}"
+                with sdk_callback_invocation(err_msg, self._logger):
+                    on_round_end(round_index)
+            if not tool_call_requests:
+                # No tool call requests -> we're done here
+                break
+            if round_index == final_round_index:
+                # We somehow received at least one tool call request,
+                # even though tools are omitted on the final round
+                err_msg = "Model requested tool use on final prediction round."
+                endpoint._handle_invalid_tool_request(err_msg)
+                break
+        num_rounds = round_index + 1
+        duration = time.perf_counter() - start_time
+        return ActResult(rounds=num_rounds, total_time_seconds=duration)
+
     @sdk_public_api_async()
     async def apply_prompt_template(
         self,
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index c4c9ae6..d8c42ea 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -43,13 +43,13 @@
 from .schemas import AnyLMStudioStruct, DictObject
 from .history import (
     AssistantResponse,
-    ToolResultMessage,
     Chat,
     ChatHistoryDataDict,
     FileHandle,
     LocalFileInput,
     _LocalFileData,
     ToolCallRequest,
+    ToolResultMessage,
 )
 from .json_api import (
     ActResult,
@@ -1316,9 +1316,8 @@ def respond(
                 pass
         return prediction_stream.result()
 
-    # Multi-round predictions are currently a sync-only handle-only feature
-    # TODO: Refactor to allow for more code sharing with the async API
-    # with defined aliases for the expected callback signatures
+    # TODO: Improve code sharing between sync and async multi-round predictions
+    # TODO: Accept async tools in the tools iterable
     @sdk_public_api()
     def act(
         self,
@@ -1395,6 +1394,9 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                 on_prompt_processing_progress(progress, round_index)
 
             on_prompt_processing_for_endpoint = _wrapped_on_prompt_processing_progress
+        # TODO: Implementation to this point is common between the sync and async APIs
+        # Implementation past this point differs (as the async API uses the loop's executor)
+
         # Request predictions until no more tool call requests are received in response
         # (or the maximum number of prediction rounds is reached)
         with ThreadPoolExecutor(max_parallel_tool_calls) as pool:
@@ -1436,6 +1438,7 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                     if isinstance(event, PredictionToolCallEvent):
                         tool_call_request = event.arg
                         tool_call_requests.append(tool_call_request)
+                        # TODO: Also handle async tool calls here
                        tool_call = endpoint.request_tool_call(tool_call_request)
                        pending_tool_calls[pool.submit(tool_call)] = tool_call_request
                prediction = prediction_stream.result()
diff --git a/tests/async/test_inference_async.py b/tests/async/test_inference_async.py
index 48ae8e2..76f506e 100644
--- a/tests/async/test_inference_async.py
+++ b/tests/async/test_inference_async.py
@@ -19,13 +19,17 @@
     LlmPredictionFragment,
     LlmPredictionStats,
     LMStudioModelNotFoundError,
+    LMStudioPredictionError,
     LMStudioPresetNotFoundError,
     PredictionResult,
+    PredictionRoundResult,
     ResponseSchema,
     TextData,
+    ToolCallRequest,
 )
 
 from ..support import (
+    ADDITION_TOOL_SPEC,
     EXPECTED_LLM_ID,
     GBNF_GRAMMAR,
     PROMPT,
@@ -33,7 +37,9 @@
     RESPONSE_SCHEMA,
     SCHEMA_FIELDS,
     SHORT_PREDICTION_CONFIG,
+    TOOL_LLM_ID,
     check_sdk_error,
+    divide,
 )
 
 
@@ -388,3 +394,126 @@ async def test_cancel_prediction_async(caplog: LogCap) -> None:
     assert stream.stats.stop_reason == "userStopped"
     # ensure __aiter__ closes correctly
     assert num_times == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.lmstudio
+async def test_tool_using_agent_async(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+    model_id = TOOL_LLM_ID
+    async with AsyncClient() as client:
+        llm = await client.llm.model(model_id)
+        chat = Chat()
+        chat.add_user_message("What is the sum of 123 and 3210?")
+        tools = [ADDITION_TOOL_SPEC]
+        # Ensure ignoring the round index passes static type checks
+        predictions: list[PredictionResult] = []
+
+        act_result = await llm.act(
+            chat, tools, on_prediction_completed=predictions.append
+        )
+        assert len(predictions) > 1
+        assert act_result.rounds == len(predictions)
+        assert "3333" in predictions[-1].content
+
+    for _logger_name, log_level, message in caplog.record_tuples:
+        if log_level != logging.INFO:
+            continue
+        if message.startswith("Tool call:"):
+            break
+    else:
+        assert False, "Failed to find tool call logging entry"
+    assert "123" in message
+    assert "3210" in message
+
+
+@pytest.mark.asyncio
+@pytest.mark.lmstudio
+async def test_tool_using_agent_callbacks_async(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+    model_id = TOOL_LLM_ID
+    async with AsyncClient() as client:
+        llm = await client.llm.model(model_id)
+        chat = Chat()
+        # Ensure the first response is a combination of text and tool use requests
+        chat.add_user_message("First say 'Hi'. Then calculate 1 + 3 with the tool.")
+        tools = [ADDITION_TOOL_SPEC]
+        round_starts: list[int] = []
+        round_ends: list[int] = []
+        first_tokens: list[int] = []
+        predictions: list[PredictionRoundResult] = []
+        fragments: list[LlmPredictionFragment] = []
+        fragment_round_indices: set[int] = set()
+
+        def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
+            last_fragment_round_index = max(fragment_round_indices, default=-1)
+            assert round_index >= last_fragment_round_index
+            fragments.append(f)
+            fragment_round_indices.add(round_index)
+
+        # TODO: Also check on_prompt_processing_progress and handling invalid messages
+        # (although it isn't clear how to provoke calls to the latter without mocking)
+        act_result = await llm.act(
+            chat,
+            tools,
+            on_first_token=first_tokens.append,
+            on_prediction_fragment=_append_fragment,
+            on_message=chat.append,
+            on_round_start=round_starts.append,
+            on_round_end=round_ends.append,
+            on_prediction_completed=predictions.append,
+        )
+        num_rounds = act_result.rounds
+        sequential_round_indices = list(range(num_rounds))
+        assert num_rounds > 1
+        assert [p.round_index for p in predictions] == sequential_round_indices
+        assert round_starts == sequential_round_indices
+        assert round_ends == sequential_round_indices
+        expected_token_indices = [p.round_index for p in predictions if p.content]
+        assert expected_token_indices == sequential_round_indices
+        assert first_tokens == expected_token_indices
+        assert fragment_round_indices == set(expected_token_indices)
+        assert len(chat._messages) == 2 * num_rounds  # No tool results in last round
+
+    cloned_chat = chat.copy()
+    assert cloned_chat._messages == chat._messages
+
+
+@pytest.mark.asyncio
+@pytest.mark.lmstudio
+async def test_tool_using_agent_error_handling_async(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+    model_id = TOOL_LLM_ID
+    async with AsyncClient() as client:
+        llm = await client.llm.model(model_id)
+        chat = Chat()
+        chat.add_user_message(
+            "Attempt to divide 1 by 0 using the tool. Explain the result."
+        )
+        tools = [divide]
+        predictions: list[PredictionRoundResult] = []
+        request_failures: list[LMStudioPredictionError] = []
+
+        def _handle_invalid_request(
+            exc: LMStudioPredictionError, request: ToolCallRequest | None
+        ) -> None:
+            if request is not None:
+                request_failures.append(exc)
+
+        act_result = await llm.act(
+            chat,
+            tools,
+            handle_invalid_tool_request=_handle_invalid_request,
+            on_prediction_completed=predictions.append,
+        )
+        assert len(predictions) > 1
+        assert act_result.rounds == len(predictions)
+        # Ensure the tool call failure was reported to the user callback
+        assert len(request_failures) == 1
+        tool_failure_exc = request_failures[0]
+        assert isinstance(tool_failure_exc, LMStudioPredictionError)
+        assert isinstance(tool_failure_exc.__cause__, ZeroDivisionError)
+        # If the content checks prove too flaky in practice, they can be dropped
+        completed_response = predictions[-1].content.lower()
+        assert "divid" in completed_response  # Accepts both "divide" and "dividing"
+        assert "zero" in completed_response
diff --git a/tests/support/__init__.py b/tests/support/__init__.py
index 230ef5f..0ec5c5d 100644
--- a/tests/support/__init__.py
+++ b/tests/support/__init__.py
@@ -1,5 +1,6 @@
 """Common test support interfaces and expected value definitions."""
 
+import logging
 import sys
 
 from contextlib import closing, contextmanager
@@ -21,6 +22,7 @@
     LlmLoadModelConfig,
     LMStudioServerError,
     LMStudioChannelClosedError,
+    ToolFunctionDefDict,
 )
 from lmstudio.json_api import ChannelEndpoint
 from lmstudio._sdk_models import LlmPredictionConfigDict, LlmStructuredPredictionSetting
@@ -292,3 +294,30 @@ def check_unfiltered_error(
     ) from exc_info.value
     # Traceback should go all the way to the raising func
     assert tb.tb_frame.f_code is err_func.__code__
+
+
+####################################################
+# Tool definitions for tool use testing
+####################################################
+
+
+def divide(numerator: float, denominator: float) -> float:
+    """Divide the given numerator by the given denominator. Return the result."""
+    return numerator / denominator
+
+
+def log_adding_two_integers(a: int, b: int) -> int:
+    """Log adding two integers together."""
+    logging.info(f"Tool call: Adding {a!r} to {b!r} as integers")
+    return int(a) + int(b)
+
+
+ADDITION_TOOL_SPEC: ToolFunctionDefDict = {
+    "name": "add",
+    "description": "Add two numbers",
+    "parameters": {
+        "a": int,
+        "b": int,
+    },
+    "implementation": log_adding_two_integers,
+}
diff --git a/tests/sync/test_inference_sync.py b/tests/sync/test_inference_sync.py
index e96c183..92dba9e 100644
--- a/tests/sync/test_inference_sync.py
+++ b/tests/sync/test_inference_sync.py
@@ -26,13 +26,17 @@
     LlmPredictionFragment,
     LlmPredictionStats,
     LMStudioModelNotFoundError,
+    LMStudioPredictionError,
     LMStudioPresetNotFoundError,
     PredictionResult,
+    PredictionRoundResult,
     ResponseSchema,
     TextData,
+    ToolCallRequest,
 )
 
 from ..support import (
+    ADDITION_TOOL_SPEC,
     EXPECTED_LLM_ID,
     GBNF_GRAMMAR,
     PROMPT,
@@ -40,7 +44,9 @@
     RESPONSE_SCHEMA,
     SCHEMA_FIELDS,
     SHORT_PREDICTION_CONFIG,
+    TOOL_LLM_ID,
     check_sdk_error,
+    divide,
 )
 
 
@@ -377,3 +383,121 @@ def test_cancel_prediction_sync(caplog: LogCap) -> None:
     assert stream.stats.stop_reason == "userStopped"
     # ensure __aiter__ closes correctly
     assert num_times == 1
+
+
+@pytest.mark.lmstudio
+def test_tool_using_agent_sync(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+    model_id = TOOL_LLM_ID
+    with Client() as client:
+        llm = client.llm.model(model_id)
+        chat = Chat()
+        chat.add_user_message("What is the sum of 123 and 3210?")
+        tools = [ADDITION_TOOL_SPEC]
+        # Ensure ignoring the round index passes static type checks
+        predictions: list[PredictionResult] = []
+
+        act_result = llm.act(chat, tools, on_prediction_completed=predictions.append)
+        assert len(predictions) > 1
+        assert act_result.rounds == len(predictions)
+        assert "3333" in predictions[-1].content
+
+    for _logger_name, log_level, message in caplog.record_tuples:
+        if log_level != logging.INFO:
+            continue
+        if message.startswith("Tool call:"):
+            break
+    else:
+        assert False, "Failed to find tool call logging entry"
+    assert "123" in message
+    assert "3210" in message
+
+
+@pytest.mark.lmstudio
+def test_tool_using_agent_callbacks_sync(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+    model_id = TOOL_LLM_ID
+    with Client() as client:
+        llm = client.llm.model(model_id)
+        chat = Chat()
+        # Ensure the first response is a combination of text and tool use requests
+        chat.add_user_message("First say 'Hi'. Then calculate 1 + 3 with the tool.")
+        tools = [ADDITION_TOOL_SPEC]
+        round_starts: list[int] = []
+        round_ends: list[int] = []
+        first_tokens: list[int] = []
+        predictions: list[PredictionRoundResult] = []
+        fragments: list[LlmPredictionFragment] = []
+        fragment_round_indices: set[int] = set()
+
+        def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
+            last_fragment_round_index = max(fragment_round_indices, default=-1)
+            assert round_index >= last_fragment_round_index
+            fragments.append(f)
+            fragment_round_indices.add(round_index)
+
+        # TODO: Also check on_prompt_processing_progress and handling invalid messages
+        # (although it isn't clear how to provoke calls to the latter without mocking)
+        act_result = llm.act(
+            chat,
+            tools,
+            on_first_token=first_tokens.append,
+            on_prediction_fragment=_append_fragment,
+            on_message=chat.append,
+            on_round_start=round_starts.append,
+            on_round_end=round_ends.append,
+            on_prediction_completed=predictions.append,
+        )
+        num_rounds = act_result.rounds
+        sequential_round_indices = list(range(num_rounds))
+        assert num_rounds > 1
+        assert [p.round_index for p in predictions] == sequential_round_indices
+        assert round_starts == sequential_round_indices
+        assert round_ends == sequential_round_indices
+        expected_token_indices = [p.round_index for p in predictions if p.content]
+        assert expected_token_indices == sequential_round_indices
+        assert first_tokens == expected_token_indices
+        assert fragment_round_indices == set(expected_token_indices)
+        assert len(chat._messages) == 2 * num_rounds  # No tool results in last round
+
+    cloned_chat = chat.copy()
+    assert cloned_chat._messages == chat._messages
+
+
+@pytest.mark.lmstudio
+def test_tool_using_agent_error_handling_sync(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+    model_id = TOOL_LLM_ID
+    with Client() as client:
+        llm = client.llm.model(model_id)
+        chat = Chat()
+        chat.add_user_message(
+            "Attempt to divide 1 by 0 using the tool. Explain the result."
+        )
+        tools = [divide]
+        predictions: list[PredictionRoundResult] = []
+        request_failures: list[LMStudioPredictionError] = []
+
+        def _handle_invalid_request(
+            exc: LMStudioPredictionError, request: ToolCallRequest | None
+        ) -> None:
+            if request is not None:
+                request_failures.append(exc)
+
+        act_result = llm.act(
+            chat,
+            tools,
+            handle_invalid_tool_request=_handle_invalid_request,
+            on_prediction_completed=predictions.append,
+        )
+        assert len(predictions) > 1
+        assert act_result.rounds == len(predictions)
+        # Ensure the tool call failure was reported to the user callback
+        assert len(request_failures) == 1
+        tool_failure_exc = request_failures[0]
+        assert isinstance(tool_failure_exc, LMStudioPredictionError)
+        assert isinstance(tool_failure_exc.__cause__, ZeroDivisionError)
+        # If the content checks prove too flaky in practice, they can be dropped
+        completed_response = predictions[-1].content.lower()
+        assert "divid" in completed_response  # Accepts both "divide" and "dividing"
+        assert "zero" in completed_response
diff --git a/tests/test_inference.py b/tests/test_inference.py
index 2f5b011..80ff704 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -12,25 +12,20 @@
 from lmstudio import (
     AsyncClient,
     Chat,
-    Client,
     LlmPredictionConfig,
-    LlmPredictionFragment,
-    LMStudioPredictionError,
     LMStudioValueError,
     PredictionResult,
-    PredictionRoundResult,
-    ToolCallRequest,
     ToolFunctionDef,
-    ToolFunctionDefDict,
 )
 from lmstudio.json_api import ChatResponseEndpoint
 from lmstudio._sdk_models import LlmToolParameters
 
 from .support import (
+    ADDITION_TOOL_SPEC,
     EXPECTED_LLM_ID,
     MAX_PREDICTED_TOKENS,
     SHORT_PREDICTION_CONFIG,
-    TOOL_LLM_ID,
+    log_adding_two_integers,
 )
 
 SC_PREDICTION_CONFIG = {
@@ -94,23 +89,6 @@ async def _request_response() -> PredictionResult:
 # TODO: write sync concurrent predictions test with external locking and concurrent.futures
 
 
-def log_adding_two_integers(a: int, b: int) -> int:
-    """Log adding two integers together."""
-    logging.info(f"Tool call: Adding {a!r} to {b!r} as integers")
-    return int(a) + int(b)
-
-
-ADDITION_TOOL_SPEC: ToolFunctionDefDict = {
-    "name": "add",
-    "description": "Add two numbers",
-    "parameters": {
-        "a": int,
-        "b": int,
-    },
-    "implementation": log_adding_two_integers,
-}
-
-
 def test_tool_def_from_callable() -> None:
     default_def = ToolFunctionDef.from_callable(log_adding_two_integers)
     assert default_def == ToolFunctionDef(
@@ -160,132 +138,3 @@ def test_duplicate_tool_names_rejected() -> None:
         LMStudioValueError, match="Duplicate tool names are not permitted"
     ):
         ChatResponseEndpoint.parse_tools(tools)
-
-
-@pytest.mark.lmstudio
-def test_tool_using_agent(caplog: LogCap) -> None:
-    # This is currently a sync-only API (it will be refactored in a future release)
-
-    caplog.set_level(logging.DEBUG)
-    model_id = TOOL_LLM_ID
-    with Client() as client:
-        llm = client.llm.model(model_id)
-        chat = Chat()
-        chat.add_user_message("What is the sum of 123 and 3210?")
-        tools = [ADDITION_TOOL_SPEC]
-        # Ensure ignoring the round index passes static type checks
-        predictions: list[PredictionResult] = []
-
-        act_result = llm.act(chat, tools, on_prediction_completed=predictions.append)
-        assert len(predictions) > 1
-        assert act_result.rounds == len(predictions)
-        assert "3333" in predictions[-1].content
-
-    for _logger_name, log_level, message in caplog.record_tuples:
-        if log_level != logging.INFO:
-            continue
-        if message.startswith("Tool call:"):
-            break
-    else:
-        assert False, "Failed to find tool call logging entry"
-    assert "123" in message
-    assert "3210" in message
-
-
-@pytest.mark.lmstudio
-def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
-    # This is currently a sync-only API (it will be refactored in a future release)
-
-    caplog.set_level(logging.DEBUG)
-    model_id = TOOL_LLM_ID
-    with Client() as client:
-        llm = client.llm.model(model_id)
-        chat = Chat()
-        # Ensure the first response is a combination of text and tool use requests
-        chat.add_user_message("First say 'Hi'. Then calculate 1 + 3 with the tool.")
-        tools = [ADDITION_TOOL_SPEC]
-        round_starts: list[int] = []
-        round_ends: list[int] = []
-        first_tokens: list[int] = []
-        predictions: list[PredictionRoundResult] = []
-        fragments: list[LlmPredictionFragment] = []
-        fragment_round_indices: set[int] = set()
-
-        def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
-            last_fragment_round_index = max(fragment_round_indices, default=-1)
-            assert round_index >= last_fragment_round_index
-            fragments.append(f)
-            fragment_round_indices.add(round_index)
-
-        # TODO: Also check on_prompt_processing_progress and handling invalid messages
-        # (although it isn't clear how to provoke calls to the latter without mocking)
-        act_result = llm.act(
-            chat,
-            tools,
-            on_first_token=first_tokens.append,
-            on_prediction_fragment=_append_fragment,
-            on_message=chat.append,
-            on_round_start=round_starts.append,
-            on_round_end=round_ends.append,
-            on_prediction_completed=predictions.append,
-        )
-        num_rounds = act_result.rounds
-        sequential_round_indices = list(range(num_rounds))
-        assert num_rounds > 1
-        assert [p.round_index for p in predictions] == sequential_round_indices
-        assert round_starts == sequential_round_indices
-        assert round_ends == sequential_round_indices
-        expected_token_indices = [p.round_index for p in predictions if p.content]
-        assert expected_token_indices == sequential_round_indices
-        assert first_tokens == expected_token_indices
-        assert fragment_round_indices == set(expected_token_indices)
-        assert len(chat._messages) == 2 * num_rounds  # No tool results in last round
-
-    cloned_chat = chat.copy()
-    assert cloned_chat._messages == chat._messages
-
-
-def divide(numerator: float, denominator: float) -> float:
-    """Divide the given numerator by the given denominator. Return the result."""
-    return numerator / denominator
-
-
-@pytest.mark.lmstudio
-def test_tool_using_agent_error_handling(caplog: LogCap) -> None:
-    # This is currently a sync-only API (it will be refactored in a future release)
-
-    caplog.set_level(logging.DEBUG)
-    model_id = TOOL_LLM_ID
-    with Client() as client:
-        llm = client.llm.model(model_id)
-        chat = Chat()
-        chat.add_user_message(
-            "Attempt to divide 1 by 0 using the tool. Explain the result."
-        )
-        tools = [divide]
-        predictions: list[PredictionRoundResult] = []
-        request_failures: list[LMStudioPredictionError] = []
-
-        def _handle_invalid_request(
-            exc: LMStudioPredictionError, request: ToolCallRequest | None
-        ) -> None:
-            if request is not None:
-                request_failures.append(exc)
-
-        act_result = llm.act(
-            chat,
-            tools,
-            handle_invalid_tool_request=_handle_invalid_request,
-            on_prediction_completed=predictions.append,
-        )
-        assert len(predictions) > 1
-        assert act_result.rounds == len(predictions)
-        # Ensure the tool call failure was reported to the user callback
-        assert len(request_failures) == 1
-        tool_failure_exc = request_failures[0]
-        assert isinstance(tool_failure_exc, LMStudioPredictionError)
-        assert isinstance(tool_failure_exc.__cause__, ZeroDivisionError)
-        # If the content checks prove too flaky in practice, they can be dropped
-        completed_response = predictions[-1].content.lower()
-        assert "divid" in completed_response  # Accepts both "divide" and "dividing"
-        assert "zero" in completed_response