
Commit c0f8916

Rename tool operations as tool actions (#9)
1 parent 4d30913 commit c0f8916

3 files changed: +34 −26 lines changed

src/lmstudio/json_api.py

Lines changed: 8 additions & 8 deletions
@@ -123,6 +123,7 @@
 # explicitly via `lmstudio.json_api`, it isn't exported
 # implicitly as part of the top-level `lmstudio` API.
 __all__ = [
+    "ActResult",
     "AnyModelSpecifier",
     "EmbeddingModelInfo",
     "EmbeddingModelInstanceInfo",
@@ -151,7 +152,6 @@
     "ModelSpecifierDict",
     "ModelQuery",
     "ModelQueryDict",
-    "OperationResult",
     "PredictionResult",
     "PredictionRoundResult",
     "SerializedLMSExtendedError",
@@ -455,9 +455,9 @@ def _to_history_content(self) -> str:
 
 @dataclass(kw_only=True, frozen=True, slots=True)
 class PredictionRoundResult(PredictionResult[str]):
-    """The result of a prediction within a multi-round tool using operation."""
+    """The result of a prediction within a multi-round tool using action."""
 
-    round_index: int  # The round within the operation that produced this result
+    round_index: int  # The round within the action that produced this result
 
     @classmethod
     def from_result(cls, result: PredictionResult[str], round_index: int) -> Self:
@@ -471,10 +471,10 @@ def from_result(cls, result: PredictionResult[str], round_index: int) -> Self:
 
 
 @dataclass(kw_only=True, frozen=True, slots=True)
-class OperationResult:
-    """Summary of a completed multi-round tool using operation."""
+class ActResult:
+    """Summary of a completed multi-round tool using action."""
 
-    # Actual operation output is reported via callbacks
+    # Detailed action results are reported via callbacks (for now)
 
     # fmt: off
     rounds: int
@@ -1073,7 +1073,7 @@ def __init__(
         on_first_token: Callable[[], None] | None = None,
         on_prediction_fragment: Callable[[LlmPredictionFragment], None] | None = None,
         on_prompt_processing_progress: Callable[[float], None] | None = None,
-        # The remaining options are only relevant for multi-round tool operations
+        # The remaining options are only relevant for multi-round tool actions
         handle_invalid_tool_request: Callable[
            [LMStudioPredictionError, _ToolCallRequest | None], str
        ]
@@ -1359,7 +1359,7 @@ def parse_tools(
         """Split tool function definitions into server and client details."""
         if not tools:
             raise LMStudioValueError(
-                "Tool operation requires at least one tool to be defined."
+                "Tool using actions require at least one tool to be defined."
             )
         llm_tool_defs: list[LlmTool] = []
         client_tool_map: dict[str, ClientToolSpec] = {}
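Stripped of diff markers, the renamed result type reads roughly as follows. The decorator, docstring, comment, and `rounds` field appear verbatim in the hunks above; the `total_time_seconds` annotation is inferred from the `ActResult(rounds=..., total_time_seconds=...)` call site in `sync_api.py`, so treat the exact type as an assumption.

```python
from dataclasses import dataclass

@dataclass(kw_only=True, frozen=True, slots=True)
class ActResult:
    """Summary of a completed multi-round tool using action."""

    # Detailed action results are reported via callbacks (for now)
    rounds: int                # Prediction rounds executed for this action
    total_time_seconds: float  # Inferred from the ActResult(...) call site
```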

src/lmstudio/sync_api.py

Lines changed: 20 additions & 12 deletions
@@ -53,6 +53,7 @@
     _ToolCallRequest,
 )
 from .json_api import (
+    ActResult,
     AnyModelSpecifier,
     AvailableModelBase,
     ChannelEndpoint,
@@ -85,7 +86,6 @@
     ModelSessionTypes,
     ModelTypesEmbedding,
     ModelTypesLlm,
-    OperationResult,
     PredictionEndpoint,
     PredictionFragmentEvent,
     PredictionResult,
@@ -1539,7 +1539,7 @@ def respond(
     # Multi-round predictions are currently a sync-only handle-only feature
     # TODO: Refactor to allow for more code sharing with the async API
     @sdk_public_api()
-    def operate(
+    def act(
         self,
         chat: Chat | ChatHistoryDataDict | str,
         tools: Iterable[ToolFunctionDef | ToolFunctionDefDict],
@@ -1559,14 +1559,14 @@ def operate(
             [LMStudioPredictionError, _ToolCallRequest | None], str
         ]
         | None = None,
-    ) -> OperationResult:
-        """Request a response (with implicit tool use) in an ongoing assistant chat session."""
-        operation_start_time = time.perf_counter()
+    ) -> ActResult:
+        """Request a response (with implicit tool use) in an ongoing agent chat session."""
+        start_time = time.perf_counter()
         # It is not yet possible to combine tool calling with requests for structured responses
         response_format = None
         if isinstance(chat, Chat):
             chat._fetch_file_handles(self._session._fetch_file_handle)
-        op_chat: Chat = Chat.from_history(chat)
+        agent_chat: Chat = Chat.from_history(chat)
         del chat
         # Multiple rounds, until all tool calls are resolved or limit is reached
         round_counter: Iterable[int]
@@ -1622,9 +1622,11 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
             # Update the endpoint definition on each iteration in order to:
             # * update the chat history with the previous round result
             # * be able to disallow tool use when the rounds are limited
+            # TODO: Refactor endpoint API to avoid repeatedly performing the
+            # LlmPredictionConfig -> KvConfigStack transformation
             endpoint = ChatResponseEndpoint(
                 self.identifier,
-                op_chat,
+                agent_chat,
                 response_format,
                 config,
                 None,  # Multiple messages are generated per round
@@ -1658,23 +1660,29 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                 tool_results = [
                     fut.result() for fut in as_completed(pending_tool_calls)
                 ]
-                requests_message = op_chat._add_assistant_tool_requests(
+                requests_message = agent_chat._add_assistant_tool_requests(
                     prediction, tool_call_requests
                 )
-                results_message = op_chat._add_tool_results(tool_results)
+                results_message = agent_chat._add_tool_results(tool_results)
                 if on_message is not None:
                     on_message(requests_message)
                     on_message(results_message)
             elif on_message is not None:
-                on_message(op_chat.add_assistant_response(prediction))
+                on_message(agent_chat.add_assistant_response(prediction))
             if on_round_end is not None:
                 on_round_end(round_index)
             if not tool_call_requests:
                 # No tool call requests -> we're done here
                 break
+            if round_index == final_round_index:
+                # We somehow received at least one tool call request,
+                # even though tools are omitted on the final round
+                err_msg = "Model requested tool use on final prediction round."
+                endpoint._handle_invalid_tool_request(err_msg)
+                break
         num_rounds = round_index + 1
-        duration = time.perf_counter() - operation_start_time
-        return OperationResult(rounds=num_rounds, total_time_seconds=duration)
+        duration = time.perf_counter() - start_time
+        return ActResult(rounds=num_rounds, total_time_seconds=duration)
 
     @sdk_public_api()
     def apply_prompt_template(
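To put the rename in context, here is a minimal usage sketch of the new entry point. The `act()` signature, the `on_message` callback, and the `ActResult` fields come from the diff above; the `lms.llm()` convenience constructor, the model identifier, the `add` tool, and the exact `ToolFunctionDef` field names are assumptions for illustration and may differ at this commit.

```python
import lmstudio as lms
from lmstudio import ToolFunctionDef

def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

# Tool definition shape per the ToolFunctionDef type in act()'s signature;
# these field names follow the released SDK and are an assumption here.
add_tool = ToolFunctionDef(
    name="add",
    description="Add two integers",
    parameters={"a": int, "b": int},
    implementation=add,
)

llm = lms.llm("qwen2.5-7b-instruct")  # hypothetical model identifier
result = llm.act(
    "What is 123 + 97?",  # chat accepts a Chat, a history dict, or a plain string
    [add_tool],           # at least one tool must be defined
    on_message=print,     # optional: observe each generated message
)
print(f"Completed in {result.rounds} round(s), {result.total_time_seconds:.2f}s")
```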

tests/test_inference.py

Lines changed: 6 additions & 6 deletions
@@ -162,7 +162,7 @@ def test_duplicate_tool_names_rejected() -> None:
 
 
 @pytest.mark.lmstudio
-def test_tool_operation(caplog: LogCap) -> None:
+def test_tool_using_agent(caplog: LogCap) -> None:
     # This is currently a sync-only API (it will be refactored after 1.0.0)
 
     caplog.set_level(logging.DEBUG)
@@ -177,9 +177,9 @@ def test_tool_operation(caplog: LogCap) -> None:
     # Ensure ignoring the round index passes static type checks
     predictions: list[PredictionResult[str]] = []
 
-    op_result = llm.operate(chat, tools, on_prediction_completed=predictions.append)
+    act_result = llm.act(chat, tools, on_prediction_completed=predictions.append)
     assert len(predictions) > 1
-    assert op_result.rounds == len(predictions)
+    assert act_result.rounds == len(predictions)
     assert "220" in predictions[-1].content
 
     for _logger_name, log_level, message in caplog.record_tuples:
@@ -194,7 +194,7 @@ def test_tool_operation(caplog: LogCap) -> None:
 
 
 @pytest.mark.lmstudio
-def test_tool_operation_callbacks(caplog: LogCap) -> None:
+def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
     # This is currently a sync-only API (it will be refactored after 1.0.0)
 
     caplog.set_level(logging.DEBUG)
@@ -222,7 +222,7 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
 
     # TODO: Also check on_prompt_processing_progress and handling invalid messages
     # (although it isn't clear how to provoke calls to the latter without mocking)
-    op_result = llm.operate(
+    act_result = llm.act(
         chat,
         tools,
         on_first_token=first_tokens.append,
@@ -232,7 +232,7 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
         on_round_end=round_ends.append,
         on_prediction_completed=predictions.append,
     )
-    num_rounds = op_result.rounds
+    num_rounds = act_result.rounds
     sequential_round_indices = list(range(num_rounds))
     assert num_rounds > 1
     assert [p.round_index for p in predictions] == sequential_round_indices
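The callback wiring these tests exercise can be summarised in the sketch below. The callback names and the `_append_fragment` signature come from the diff; the `llm`, `chat`, and `tools` setup is assumed (it happens earlier in the test module and is not part of this commit), as are the round-index payload passed to `on_first_token` and the fragment's `content` attribute. Note the typing detail the first test checks: `on_prediction_completed` receives `PredictionRoundResult` instances, which subclass `PredictionResult[str]`, so a list of the base type collects them without complaint.

```python
from lmstudio import LlmPredictionFragment, PredictionResult

# llm, chat, and tools are assumed to be prepared as in the surrounding tests
first_tokens: list[int] = []  # round indices, as act() reports each first token
round_ends: list[int] = []
predictions: list[PredictionResult[str]] = []  # round_index is safely ignorable
fragments_by_round: dict[int, list[str]] = {}

def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
    # Group streamed fragments by the round that produced them
    fragments_by_round.setdefault(round_index, []).append(f.content)

act_result = llm.act(
    chat,
    tools,
    on_first_token=first_tokens.append,
    on_prediction_fragment=_append_fragment,
    on_round_end=round_ends.append,
    on_prediction_completed=predictions.append,
)
# One completed prediction per round, as the tests assert
assert act_result.rounds == len(predictions)
```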
