
Commit 355598e

Rename ActionResult to just ActResult
1 parent d4c039c commit 355598e

3 files changed (+27, -19 lines)

src/lmstudio/json_api.py

Lines changed: 2 additions & 2 deletions
@@ -123,6 +123,7 @@
 # explicitly via `lmstudio.json_api`, it isn't exported
 # implicitly as part of the top-level `lmstudio` API.
 __all__ = [
+    "ActResult",
     "AnyModelSpecifier",
     "EmbeddingModelInfo",
     "EmbeddingModelInstanceInfo",
@@ -151,7 +152,6 @@
     "ModelSpecifierDict",
     "ModelQuery",
     "ModelQueryDict",
-    "ActionResult",
     "PredictionResult",
     "PredictionRoundResult",
     "SerializedLMSExtendedError",
@@ -471,7 +471,7 @@ def from_result(cls, result: PredictionResult[str], round_index: int) -> Self:


 @dataclass(kw_only=True, frozen=True, slots=True)
-class ActionResult:
+class ActResult:
     """Summary of a completed multi-round tool using action."""

     # Detailed action results are reported via callbacks (for now)
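After this commit, the renamed result type reduces to roughly the following minimal sketch; the `rounds` and `total_time_seconds` fields and their types are inferred from the `ActResult(...)` construction in sync_api.py below, not from this hunk:

    from dataclasses import dataclass

    @dataclass(kw_only=True, frozen=True, slots=True)
    class ActResult:
        """Summary of a completed multi-round tool using action."""

        # Detailed action results are reported via callbacks (for now)
        rounds: int                # number of prediction rounds performed
        total_time_seconds: float  # wall-clock duration of the whole act() call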

src/lmstudio/sync_api.py

Lines changed: 19 additions & 11 deletions
@@ -53,7 +53,7 @@
     _ToolCallRequest,
 )
 from .json_api import (
-    ActionResult,
+    ActResult,
     AnyModelSpecifier,
     AvailableModelBase,
     ChannelEndpoint,
@@ -1559,14 +1559,14 @@ def act(
             [LMStudioPredictionError, _ToolCallRequest | None], str
         ]
         | None = None,
-    ) -> ActionResult:
-        """Request a response (with implicit tool use) in an ongoing assistant chat session."""
-        action_start_time = time.perf_counter()
+    ) -> ActResult:
+        """Request a response (with implicit tool use) in an ongoing agent chat session."""
+        start_time = time.perf_counter()
         # It is not yet possible to combine tool calling with requests for structured responses
         response_format = None
         if isinstance(chat, Chat):
             chat._fetch_file_handles(self._session._fetch_file_handle)
-        action_chat: Chat = Chat.from_history(chat)
+        agent_chat: Chat = Chat.from_history(chat)
         del chat
         # Multiple rounds, until all tool calls are resolved or limit is reached
         round_counter: Iterable[int]
@@ -1622,9 +1622,11 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
             # Update the endpoint definition on each iteration in order to:
             # * update the chat history with the previous round result
             # * be able to disallow tool use when the rounds are limited
+            # TODO: Refactor endpoint API to avoid repeatedly performing the
+            # LlmPredictionConfig -> KvConfigStack transformation
             endpoint = ChatResponseEndpoint(
                 self.identifier,
-                action_chat,
+                agent_chat,
                 response_format,
                 config,
                 None,  # Multiple messages are generated per round
@@ -1658,23 +1660,29 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None:
                 tool_results = [
                     fut.result() for fut in as_completed(pending_tool_calls)
                 ]
-                requests_message = action_chat._add_assistant_tool_requests(
+                requests_message = agent_chat._add_assistant_tool_requests(
                     prediction, tool_call_requests
                 )
-                results_message = action_chat._add_tool_results(tool_results)
+                results_message = agent_chat._add_tool_results(tool_results)
                 if on_message is not None:
                     on_message(requests_message)
                     on_message(results_message)
             elif on_message is not None:
-                on_message(action_chat.add_assistant_response(prediction))
+                on_message(agent_chat.add_assistant_response(prediction))
             if on_round_end is not None:
                 on_round_end(round_index)
             if not tool_call_requests:
                 # No tool call requests -> we're done here
                 break
+            if round_index == final_round_index:
+                # We somehow received at least one tool call request,
+                # even though tools are omitted on the final round
+                err_msg = "Model requested tool use on final prediction round."
+                endpoint._handle_invalid_tool_request(err_msg)
+                break
         num_rounds = round_index + 1
-        duration = time.perf_counter() - action_start_time
-        return ActionResult(rounds=num_rounds, total_time_seconds=duration)
+        duration = time.perf_counter() - start_time
+        return ActResult(rounds=num_rounds, total_time_seconds=duration)

     @sdk_public_api()
     def apply_prompt_template(
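Caller-side, the rename is the only visible change; a hypothetical usage sketch, assuming `llm`, `chat`, and `tools` are set up as in the tests below:

    act_result = llm.act(chat, tools)
    print(f"Completed in {act_result.rounds} round(s), "
          f"{act_result.total_time_seconds:.2f}s total")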

tests/test_inference.py

Lines changed: 6 additions & 6 deletions
@@ -162,7 +162,7 @@ def test_duplicate_tool_names_rejected() -> None:


 @pytest.mark.lmstudio
-def test_tool_action(caplog: LogCap) -> None:
+def test_tool_using_agent(caplog: LogCap) -> None:
     # This is currently a sync-only API (it will be refactored after 1.0.0)

     caplog.set_level(logging.DEBUG)
@@ -177,9 +177,9 @@ def test_tool_action(caplog: LogCap) -> None:
     # Ensure ignoring the round index passes static type checks
     predictions: list[PredictionResult[str]] = []

-    action_result = llm.act(chat, tools, on_prediction_completed=predictions.append)
+    act_result = llm.act(chat, tools, on_prediction_completed=predictions.append)
     assert len(predictions) > 1
-    assert action_result.rounds == len(predictions)
+    assert act_result.rounds == len(predictions)
     assert "220" in predictions[-1].content

     for _logger_name, log_level, message in caplog.record_tuples:
@@ -194,7 +194,7 @@ def test_tool_action(caplog: LogCap) -> None:


 @pytest.mark.lmstudio
-def test_tool_action_callbacks(caplog: LogCap) -> None:
+def test_tool_using_agent_callbacks(caplog: LogCap) -> None:
     # This is currently a sync-only API (it will be refactored after 1.0.0)

     caplog.set_level(logging.DEBUG)
@@ -222,7 +222,7 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:

     # TODO: Also check on_prompt_processing_progress and handling invalid messages
     # (although it isn't clear how to provoke calls to the latter without mocking)
-    action_result = llm.act(
+    act_result = llm.act(
         chat,
         tools,
         on_first_token=first_tokens.append,
@@ -232,7 +232,7 @@ def _append_fragment(f: LlmPredictionFragment, round_index: int) -> None:
         on_round_end=round_ends.append,
         on_prediction_completed=predictions.append,
     )
-    num_rounds = action_result.rounds
+    num_rounds = act_result.rounds
     sequential_round_indices = list(range(num_rounds))
     assert num_rounds > 1
     assert [p.round_index for p in predictions] == sequential_round_indices
