Commit bf74ba9

Split Ollama entity (home-assistant#147769)

1 parent 70856bd

3 files changed: +268 −245 lines changed


homeassistant/components/ollama/conversation.py

Lines changed: 8 additions & 243 deletions
@@ -2,41 +2,18 @@
 
 from __future__ import annotations
 
-from collections.abc import AsyncGenerator, AsyncIterator, Callable
-import json
-import logging
-from typing import Any, Literal
-
-import ollama
-from voluptuous_openapi import convert
+from typing import Literal
 
 from homeassistant.components import assist_pipeline, conversation
 from homeassistant.config_entries import ConfigEntry, ConfigSubentry
 from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
 from homeassistant.core import HomeAssistant
-from homeassistant.exceptions import HomeAssistantError
-from homeassistant.helpers import device_registry as dr, intent, llm
+from homeassistant.helpers import intent
 from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
 
 from . import OllamaConfigEntry
-from .const import (
-    CONF_KEEP_ALIVE,
-    CONF_MAX_HISTORY,
-    CONF_MODEL,
-    CONF_NUM_CTX,
-    CONF_PROMPT,
-    CONF_THINK,
-    DEFAULT_KEEP_ALIVE,
-    DEFAULT_MAX_HISTORY,
-    DEFAULT_NUM_CTX,
-    DOMAIN,
-)
-from .models import MessageHistory, MessageRole
-
-# Max number of back and forth with the LLM to generate a response
-MAX_TOOL_ITERATIONS = 10
-
-_LOGGER = logging.getLogger(__name__)
+from .const import CONF_PROMPT, DOMAIN
+from .entity import OllamaBaseLLMEntity
 
 
 async def async_setup_entry(
@@ -55,147 +32,18 @@ async def async_setup_entry(
     )
 
 
-def _format_tool(
-    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
-) -> dict[str, Any]:
-    """Format tool specification."""
-    tool_spec = {
-        "name": tool.name,
-        "parameters": convert(tool.parameters, custom_serializer=custom_serializer),
-    }
-    if tool.description:
-        tool_spec["description"] = tool.description
-    return {"type": "function", "function": tool_spec}
-
-
-def _fix_invalid_arguments(value: Any) -> Any:
-    """Attempt to repair incorrectly formatted json function arguments.
-
-    Small models (for example llama3.1 8B) may produce invalid argument values
-    which we attempt to repair here.
-    """
-    if not isinstance(value, str):
-        return value
-    if (value.startswith("[") and value.endswith("]")) or (
-        value.startswith("{") and value.endswith("}")
-    ):
-        try:
-            return json.loads(value)
-        except json.decoder.JSONDecodeError:
-            pass
-    return value
-
-
-def _parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
-    """Rewrite ollama tool arguments.
-
-    This function improves tool use quality by fixing common mistakes made by
-    small local tool use models. This will repair invalid json arguments and
-    omit unnecessary arguments with empty values that will fail intent parsing.
-    """
-    return {k: _fix_invalid_arguments(v) for k, v in arguments.items() if v}
-
-
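
The three helpers removed above implement tool-call hygiene for small local models: _format_tool converts an HA llm.Tool into the OpenAI-style function spec ollama expects, while _fix_invalid_arguments and _parse_tool_args repair JSON-encoded argument strings and drop empty arguments. The snippet below is a standalone sketch of that repair behavior in plain Python, outside Home Assistant; the example argument names are hypothetical:

import json
from typing import Any


def fix_invalid_arguments(value: Any) -> Any:
    """Decode a string that looks like a JSON object/array; pass others through."""
    if not isinstance(value, str):
        return value
    if (value.startswith("[") and value.endswith("]")) or (
        value.startswith("{") and value.endswith("}")
    ):
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            pass
    return value


def parse_tool_args(arguments: dict[str, Any]) -> dict[str, Any]:
    """Repair JSON-encoded values and drop arguments with empty values."""
    return {k: fix_invalid_arguments(v) for k, v in arguments.items() if v}


# A small model may emit a JSON-encoded string where a list was expected, plus
# an empty argument that would fail intent parsing; both get cleaned up:
print(parse_tool_args({"entities": '["light.kitchen"]', "brightness": ""}))
# -> {'entities': ['light.kitchen']}
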
-def _convert_content(
-    chat_content: (
-        conversation.Content
-        | conversation.ToolResultContent
-        | conversation.AssistantContent
-    ),
-) -> ollama.Message:
-    """Create tool response content."""
-    if isinstance(chat_content, conversation.ToolResultContent):
-        return ollama.Message(
-            role=MessageRole.TOOL.value,
-            content=json.dumps(chat_content.tool_result),
-        )
-    if isinstance(chat_content, conversation.AssistantContent):
-        return ollama.Message(
-            role=MessageRole.ASSISTANT.value,
-            content=chat_content.content,
-            tool_calls=[
-                ollama.Message.ToolCall(
-                    function=ollama.Message.ToolCall.Function(
-                        name=tool_call.tool_name,
-                        arguments=tool_call.tool_args,
-                    )
-                )
-                for tool_call in chat_content.tool_calls or ()
-            ],
-        )
-    if isinstance(chat_content, conversation.UserContent):
-        return ollama.Message(
-            role=MessageRole.USER.value,
-            content=chat_content.content,
-        )
-    if isinstance(chat_content, conversation.SystemContent):
-        return ollama.Message(
-            role=MessageRole.SYSTEM.value,
-            content=chat_content.content,
-        )
-    raise TypeError(f"Unexpected content type: {type(chat_content)}")
-
-
-async def _transform_stream(
-    result: AsyncIterator[ollama.ChatResponse],
-) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
-    """Transform the response stream into HA format.
-
-    An Ollama streaming response may come in chunks like this:
-
-    response: message=Message(role="assistant", content="Paris")
-    response: message=Message(role="assistant", content=".")
-    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
-    response: message=Message(role="assistant", tool_calls=[...])
-    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
-
-    This generator conforms to the chatlog delta stream expectations in that it
-    yields deltas, then the role only once the response is done.
-    """
-
-    new_msg = True
-    async for response in result:
-        _LOGGER.debug("Received response: %s", response)
-        response_message = response["message"]
-        chunk: conversation.AssistantContentDeltaDict = {}
-        if new_msg:
-            new_msg = False
-            chunk["role"] = "assistant"
-        if (tool_calls := response_message.get("tool_calls")) is not None:
-            chunk["tool_calls"] = [
-                llm.ToolInput(
-                    tool_name=tool_call["function"]["name"],
-                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
-                )
-                for tool_call in tool_calls
-            ]
-        if (content := response_message.get("content")) is not None:
-            chunk["content"] = content
-        if response_message.get("done"):
-            new_msg = True
-        yield chunk
-
-
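
The docstring above defines the delta-stream contract: the assistant role is emitted once per message, then content and tool-call deltas follow until a done chunk resets the state for the next message. Below is a standalone sketch of that pattern with plain dicts standing in for ollama.ChatResponse and the AssistantContentDeltaDict chunks; the fake stream and its payloads are simplified stand-ins, not the real schema:

import asyncio
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any


async def transform_stream(
    result: AsyncIterator[dict[str, Any]],
) -> AsyncGenerator[dict[str, Any], None]:
    """Yield the role once per message, then plain content deltas."""
    new_msg = True
    async for response in result:
        message = response["message"]
        chunk: dict[str, Any] = {}
        if new_msg:
            new_msg = False
            chunk["role"] = "assistant"
        if (content := message.get("content")) is not None:
            chunk["content"] = content
        if message.get("done"):
            new_msg = True  # the next chunk starts a new assistant message
        yield chunk


async def main() -> None:
    async def fake_ollama_stream() -> AsyncIterator[dict[str, Any]]:
        # Simplified stand-ins for streamed chat response chunks
        yield {"message": {"role": "assistant", "content": "Paris"}}
        yield {"message": {"role": "assistant", "content": ".", "done": True}}

    async for delta in transform_stream(fake_ollama_stream()):
        print(delta)
    # {'role': 'assistant', 'content': 'Paris'}
    # {'content': '.'}


asyncio.run(main())
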
 class OllamaConversationEntity(
-    conversation.ConversationEntity, conversation.AbstractConversationAgent
+    conversation.ConversationEntity,
+    conversation.AbstractConversationAgent,
+    OllamaBaseLLMEntity,
 ):
     """Ollama conversation agent."""
 
     _attr_supports_streaming = True
 
     def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
         """Initialize the agent."""
-        self.entry = entry
-        self.subentry = subentry
-        self._attr_name = subentry.title
-        self._attr_unique_id = subentry.subentry_id
-        self._attr_device_info = dr.DeviceInfo(
-            identifiers={(DOMAIN, subentry.subentry_id)},
-            name=subentry.title,
-            manufacturer="Ollama",
-            model=entry.data[CONF_MODEL],
-            entry_type=dr.DeviceEntryType.SERVICE,
-        )
+        super().__init__(entry, subentry)
         if self.subentry.data.get(CONF_LLM_HASS_API):
             self._attr_supported_features = (
                 conversation.ConversationEntityFeature.CONTROL
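
With the new base class in the inheritance list, the identity setup removed from __init__ above presumably lands in OllamaBaseLLMEntity, in the new entity.py that is one of the three files this commit touches but is not shown on this page. The following is only a plausible sketch of that base class, reconstructed from the code removed above; the actual file in the commit may differ:

# Hypothetical reconstruction of entity.py's base class; the file is not shown
# in this diff, so the real implementation may differ.
from homeassistant.config_entries import ConfigSubentry
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity import Entity

from . import OllamaConfigEntry
from .const import CONF_MODEL, DOMAIN


class OllamaBaseLLMEntity(Entity):
    """Base entity shared by the Ollama platforms split out of conversation.py."""

    def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
        """Initialize the entity with the per-subentry device identity."""
        self.entry = entry
        self.subentry = subentry
        self._attr_name = subentry.title
        self._attr_unique_id = subentry.subentry_id
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, subentry.subentry_id)},
            name=subentry.title,
            manufacturer="Ollama",
            model=entry.data[CONF_MODEL],
            entry_type=dr.DeviceEntryType.SERVICE,
        )
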
@@ -255,89 +103,6 @@ async def _async_handle_message(
             continue_conversation=chat_log.continue_conversation,
         )
 
-    async def _async_handle_chat_log(
-        self,
-        chat_log: conversation.ChatLog,
-    ) -> None:
-        """Generate an answer for the chat log."""
-        settings = {**self.entry.data, **self.subentry.data}
-
-        client = self.entry.runtime_data
-        model = settings[CONF_MODEL]
-
-        tools: list[dict[str, Any]] | None = None
-        if chat_log.llm_api:
-            tools = [
-                _format_tool(tool, chat_log.llm_api.custom_serializer)
-                for tool in chat_log.llm_api.tools
-            ]
-
-        message_history: MessageHistory = MessageHistory(
-            [_convert_content(content) for content in chat_log.content]
-        )
-        max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
-        self._trim_history(message_history, max_messages)
-
-        # Get response
-        # To prevent infinite loops, we limit the number of iterations
-        for _iteration in range(MAX_TOOL_ITERATIONS):
-            try:
-                response_generator = await client.chat(
-                    model=model,
-                    # Make a copy of the messages because we mutate the list later
-                    messages=list(message_history.messages),
-                    tools=tools,
-                    stream=True,
-                    # keep_alive requires specifying unit. In this case, seconds
-                    keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
-                    options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
-                    think=settings.get(CONF_THINK),
-                )
-            except (ollama.RequestError, ollama.ResponseError) as err:
-                _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
-                raise HomeAssistantError(
-                    f"Sorry, I had a problem talking to the Ollama server: {err}"
-                ) from err
-
-            message_history.messages.extend(
-                [
-                    _convert_content(content)
-                    async for content in chat_log.async_add_delta_content_stream(
-                        self.entity_id, _transform_stream(response_generator)
-                    )
-                ]
-            )
-
-            if not chat_log.unresponded_tool_results:
-                break
-
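
The removed handler above also documents the ollama client plumbing: keep_alive must carry an explicit unit (seconds here), num_ctx is passed through options, and the response is streamed. Outside Home Assistant, the same call pattern against the ollama Python package looks roughly like this; host, model, and prompt are placeholders:

import asyncio

import ollama


async def main() -> None:
    client = ollama.AsyncClient(host="http://localhost:11434")
    stream = await client.chat(
        model="llama3.1",
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=True,
        # keep_alive requires a unit; "30s" keeps the model loaded for 30 seconds
        keep_alive="30s",
        # num_ctx sets the model's context window size in tokens
        options={"num_ctx": 8192},
    )
    async for chunk in stream:
        # Each chunk mirrors the streamed response shape iterated over in the diff
        print(chunk["message"]["content"], end="", flush=True)


asyncio.run(main())
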
-    def _trim_history(self, message_history: MessageHistory, max_messages: int) -> None:
-        """Trims excess messages from a single history.
-
-        This enforces a configurable limit on the size the history may take
-        up in the context window.
-
-        Note that some messages in the history may not be from ollama and may
-        come from other agents, so the assumptions here may not strictly hold,
-        but generally should be effective.
-        """
-        if max_messages < 1:
-            # Keep all messages
-            return
-
-        # Ignore the in progress user message
-        num_previous_rounds = message_history.num_user_messages - 1
-        if num_previous_rounds >= max_messages:
-            # Trim history but keep system prompt (first message).
-            # Every other message should be an assistant message, so keep 2x
-            # message objects. Also keep the last in progress user message
-            num_keep = 2 * max_messages + 1
-            drop_index = len(message_history.messages) - num_keep
-            message_history.messages = [
-                message_history.messages[0],
-                *message_history.messages[drop_index:],
-            ]
-
     async def _async_entry_update_listener(
         self, hass: HomeAssistant, entry: ConfigEntry
     ) -> None:
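
The arithmetic in the removed _trim_history is easiest to see with concrete numbers. A standalone walk-through with plain lists standing in for MessageHistory, using max_messages=2:

# With max_messages=2, two completed user/assistant rounds plus the in-progress
# user message are kept, along with the system prompt (first message).
messages = ["system", "u1", "a1", "u2", "a2", "u3", "a3", "u4"]
max_messages = 2
num_user_messages = 4  # u1..u4; u4 is the in-progress message

num_previous_rounds = num_user_messages - 1  # ignore the in-progress u4
if num_previous_rounds >= max_messages:
    num_keep = 2 * max_messages + 1  # 2 messages per round + in-progress user
    drop_index = len(messages) - num_keep
    messages = [messages[0], *messages[drop_index:]]

print(messages)
# -> ['system', 'u2', 'a2', 'u3', 'a3', 'u4']
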
