|
53 | 53 | _ToolCallRequest, |
54 | 54 | ) |
55 | 55 | from .json_api import ( |
56 | | - ActionResult, |
| 56 | + ActResult, |
57 | 57 | AnyModelSpecifier, |
58 | 58 | AvailableModelBase, |
59 | 59 | ChannelEndpoint, |
@@ -1559,14 +1559,14 @@ def act( |
1559 | 1559 | [LMStudioPredictionError, _ToolCallRequest | None], str |
1560 | 1560 | ] |
1561 | 1561 | | None = None, |
1562 | | - ) -> ActionResult: |
1563 | | - """Request a response (with implicit tool use) in an ongoing assistant chat session.""" |
1564 | | - action_start_time = time.perf_counter() |
| 1562 | + ) -> ActResult: |
| 1563 | + """Request a response (with implicit tool use) in an ongoing agent chat session.""" |
| 1564 | + start_time = time.perf_counter() |
1565 | 1565 | # It is not yet possible to combine tool calling with requests for structured responses |
1566 | 1566 | response_format = None |
1567 | 1567 | if isinstance(chat, Chat): |
1568 | 1568 | chat._fetch_file_handles(self._session._fetch_file_handle) |
1569 | | - action_chat: Chat = Chat.from_history(chat) |
| 1569 | + agent_chat: Chat = Chat.from_history(chat) |
1570 | 1570 | del chat |
1571 | 1571 | # Multiple rounds, until all tool calls are resolved or limit is reached |
1572 | 1572 | round_counter: Iterable[int] |
@@ -1622,9 +1622,11 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None: |
1622 | 1622 | # Update the endpoint definition on each iteration in order to: |
1623 | 1623 | # * update the chat history with the previous round result |
1624 | 1624 | # * be able to disallow tool use when the rounds are limited |
| 1625 | + # TODO: Refactor endpoint API to avoid repeatedly performing the |
| 1626 | + # LlmPredictionConfig -> KvConfigStack transformation |
1625 | 1627 | endpoint = ChatResponseEndpoint( |
1626 | 1628 | self.identifier, |
1627 | | - action_chat, |
| 1629 | + agent_chat, |
1628 | 1630 | response_format, |
1629 | 1631 | config, |
1630 | 1632 | None, # Multiple messages are generated per round |
@@ -1658,23 +1660,29 @@ def _wrapped_on_prompt_processing_progress(progress: float) -> None: |
1658 | 1660 | tool_results = [ |
1659 | 1661 | fut.result() for fut in as_completed(pending_tool_calls) |
1660 | 1662 | ] |
1661 | | - requests_message = action_chat._add_assistant_tool_requests( |
| 1663 | + requests_message = agent_chat._add_assistant_tool_requests( |
1662 | 1664 | prediction, tool_call_requests |
1663 | 1665 | ) |
1664 | | - results_message = action_chat._add_tool_results(tool_results) |
| 1666 | + results_message = agent_chat._add_tool_results(tool_results) |
1665 | 1667 | if on_message is not None: |
1666 | 1668 | on_message(requests_message) |
1667 | 1669 | on_message(results_message) |
1668 | 1670 | elif on_message is not None: |
1669 | | - on_message(action_chat.add_assistant_response(prediction)) |
| 1671 | + on_message(agent_chat.add_assistant_response(prediction)) |
1670 | 1672 | if on_round_end is not None: |
1671 | 1673 | on_round_end(round_index) |
1672 | 1674 | if not tool_call_requests: |
1673 | 1675 | # No tool call requests -> we're done here |
1674 | 1676 | break |
| 1677 | + if round_index == final_round_index: |
| 1678 | + # We somehow received at least one tool call request, |
| 1679 | + # even though tools are omitted on the final round |
| 1680 | + err_msg = "Model requested tool use on final prediction round." |
| 1681 | + endpoint._handle_invalid_tool_request(err_msg) |
| 1682 | + break |
1675 | 1683 | num_rounds = round_index + 1 |
1676 | | - duration = time.perf_counter() - action_start_time |
1677 | | - return ActionResult(rounds=num_rounds, total_time_seconds=duration) |
| 1684 | + duration = time.perf_counter() - start_time |
| 1685 | + return ActResult(rounds=num_rounds, total_time_seconds=duration) |
1678 | 1686 |
|
1679 | 1687 | @sdk_public_api() |
1680 | 1688 | def apply_prompt_template( |
|
0 commit comments