22
33from __future__ import annotations
44
5- from collections .abc import AsyncGenerator , AsyncIterator , Callable
6- import json
7- import logging
8- from typing import Any , Literal
9-
10- import ollama
11- from voluptuous_openapi import convert
5+ from typing import Literal
126
137from homeassistant .components import assist_pipeline , conversation
148from homeassistant .config_entries import ConfigEntry , ConfigSubentry
159from homeassistant .const import CONF_LLM_HASS_API , MATCH_ALL
1610from homeassistant .core import HomeAssistant
17- from homeassistant .exceptions import HomeAssistantError
18- from homeassistant .helpers import device_registry as dr , intent , llm
11+ from homeassistant .helpers import intent
1912from homeassistant .helpers .entity_platform import AddConfigEntryEntitiesCallback
2013
2114from . import OllamaConfigEntry
22- from .const import (
23- CONF_KEEP_ALIVE ,
24- CONF_MAX_HISTORY ,
25- CONF_MODEL ,
26- CONF_NUM_CTX ,
27- CONF_PROMPT ,
28- CONF_THINK ,
29- DEFAULT_KEEP_ALIVE ,
30- DEFAULT_MAX_HISTORY ,
31- DEFAULT_NUM_CTX ,
32- DOMAIN ,
33- )
34- from .models import MessageHistory , MessageRole
35-
36- # Max number of back and forth with the LLM to generate a response
37- MAX_TOOL_ITERATIONS = 10
38-
39- _LOGGER = logging .getLogger (__name__ )
15+ from .const import CONF_PROMPT , DOMAIN
16+ from .entity import OllamaBaseLLMEntity
4017
4118
4219async def async_setup_entry (
@@ -55,147 +32,18 @@ async def async_setup_entry(
5532 )
5633
5734
def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> dict[str, Any]:
    """Build the function-tool payload sent to Ollama for one LLM tool.

    The tool's parameter schema is passed through ``convert`` together with
    ``custom_serializer``. A ``description`` entry is added only when the
    tool has a non-empty description.
    """
    parameters = convert(tool.parameters, custom_serializer=custom_serializer)
    function_spec: dict[str, Any] = {
        "name": tool.name,
        "parameters": parameters,
    }
    if tool.description:
        function_spec["description"] = tool.description
    return {"type": "function", "function": function_spec}
69-
70-
71- def _fix_invalid_arguments (value : Any ) -> Any :
72- """Attempt to repair incorrectly formatted json function arguments.
73-
74- Small models (for example llama3.1 8B) may produce invalid argument values
75- which we attempt to repair here.
76- """
77- if not isinstance (value , str ):
78- return value
79- if (value .startswith ("[" ) and value .endswith ("]" )) or (
80- value .startswith ("{" ) and value .endswith ("}" )
81- ):
82- try :
83- return json .loads (value )
84- except json .decoder .JSONDecodeError :
85- pass
86- return value
87-
88-
89- def _parse_tool_args (arguments : dict [str , Any ]) -> dict [str , Any ]:
90- """Rewrite ollama tool arguments.
91-
92- This function improves tool use quality by fixing common mistakes made by
93- small local tool use models. This will repair invalid json arguments and
94- omit unnecessary arguments with empty values that will fail intent parsing.
95- """
96- return {k : _fix_invalid_arguments (v ) for k , v in arguments .items () if v }
97-
98-
def _convert_content(
    chat_content: (
        conversation.Content
        | conversation.ToolResultContent
        | conversation.AssistantContent
    ),
) -> ollama.Message:
    """Translate one chat log entry into the matching Ollama message.

    Tool results are JSON-encoded into a tool-role message, assistant content
    carries its tool calls along, and user/system content map to plain
    messages of the corresponding role. Any other type raises ``TypeError``.
    """
    if isinstance(chat_content, conversation.ToolResultContent):
        return ollama.Message(
            role=MessageRole.TOOL.value,
            content=json.dumps(chat_content.tool_result),
        )

    if isinstance(chat_content, conversation.AssistantContent):
        converted_calls = []
        for tool_call in chat_content.tool_calls or ():
            function = ollama.Message.ToolCall.Function(
                name=tool_call.tool_name,
                arguments=tool_call.tool_args,
            )
            converted_calls.append(ollama.Message.ToolCall(function=function))
        return ollama.Message(
            role=MessageRole.ASSISTANT.value,
            content=chat_content.content,
            tool_calls=converted_calls,
        )

    # User and system content differ only in the role they map to.
    plain_roles = (
        (conversation.UserContent, MessageRole.USER),
        (conversation.SystemContent, MessageRole.SYSTEM),
    )
    for content_type, role in plain_roles:
        if isinstance(chat_content, content_type):
            return ollama.Message(role=role.value, content=chat_content.content)

    raise TypeError(f"Unexpected content type: {type(chat_content)}")
137-
138-
async def _transform_stream(
    result: AsyncIterator[ollama.ChatResponse],
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
    """Transform the response stream into HA format.

    An Ollama streaming response may come in chunks like this:

    response: message=Message(role="assistant", content="Paris")
    response: message=Message(role="assistant", content=".")
    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"
    response: message=Message(role="assistant", tool_calls=[...])
    response: message=Message(role="assistant", content=""), done: True, done_reason: "stop"

    This generator yields one delta per response chunk. The ``role`` key is
    emitted only on the first delta of each message; a chunk flagged ``done``
    ends the current message so the next chunk starts a new one.
    """

    new_msg = True
    async for response in result:
        _LOGGER.debug("Received response: %s", response)
        response_message = response["message"]
        chunk: conversation.AssistantContentDeltaDict = {}
        if new_msg:
            new_msg = False
            chunk["role"] = "assistant"
        if (tool_calls := response_message.get("tool_calls")) is not None:
            chunk["tool_calls"] = [
                llm.ToolInput(
                    tool_name=tool_call["function"]["name"],
                    tool_args=_parse_tool_args(tool_call["function"]["arguments"]),
                )
                for tool_call in tool_calls
            ]
        if (content := response_message.get("content")) is not None:
            chunk["content"] = content
        # ``done`` is reported on the response itself (see the example above),
        # not on the nested message, so it has to be read from ``response``.
        if response.get("done"):
            new_msg = True
        yield chunk
177-
178-
17935class OllamaConversationEntity (
180- conversation .ConversationEntity , conversation .AbstractConversationAgent
36+ conversation .ConversationEntity ,
37+ conversation .AbstractConversationAgent ,
38+ OllamaBaseLLMEntity ,
18139):
18240 """Ollama conversation agent."""
18341
18442 _attr_supports_streaming = True
18543
18644 def __init__ (self , entry : OllamaConfigEntry , subentry : ConfigSubentry ) -> None :
18745 """Initialize the agent."""
188- self .entry = entry
189- self .subentry = subentry
190- self ._attr_name = subentry .title
191- self ._attr_unique_id = subentry .subentry_id
192- self ._attr_device_info = dr .DeviceInfo (
193- identifiers = {(DOMAIN , subentry .subentry_id )},
194- name = subentry .title ,
195- manufacturer = "Ollama" ,
196- model = entry .data [CONF_MODEL ],
197- entry_type = dr .DeviceEntryType .SERVICE ,
198- )
46+ super ().__init__ (entry , subentry )
19947 if self .subentry .data .get (CONF_LLM_HASS_API ):
20048 self ._attr_supported_features = (
20149 conversation .ConversationEntityFeature .CONTROL
@@ -255,89 +103,6 @@ async def _async_handle_message(
255103 continue_conversation = chat_log .continue_conversation ,
256104 )
257105
async def _async_handle_chat_log(
    self,
    chat_log: conversation.ChatLog,
) -> None:
    """Generate an answer for the chat log.

    Entry and subentry data are merged into one settings mapping (subentry
    values win). The chat log is converted to Ollama messages, trimmed to the
    configured history size, and sent to the model. The request is repeated
    while the chat log reports unresponded tool results, up to
    ``MAX_TOOL_ITERATIONS`` times.

    Raises ``HomeAssistantError`` when the Ollama client reports a request or
    response error.
    """
    settings = {**self.entry.data, **self.subentry.data}
    ollama_client = self.entry.runtime_data
    model_name = settings[CONF_MODEL]

    llm_api = chat_log.llm_api
    tool_specs: list[dict[str, Any]] | None = (
        [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools]
        if llm_api
        else None
    )

    history: MessageHistory = MessageHistory(
        [_convert_content(entry) for entry in chat_log.content]
    )
    self._trim_history(
        history, int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
    )

    # keep_alive requires specifying unit. In this case, seconds
    keep_alive = f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s"
    num_ctx = settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)
    think = settings.get(CONF_THINK)

    # Bound the number of model round trips so tool calls cannot loop forever.
    for _iteration in range(MAX_TOOL_ITERATIONS):
        try:
            response_stream = await ollama_client.chat(
                model=model_name,
                # Pass a snapshot: the history list is extended further down.
                messages=list(history.messages),
                tools=tool_specs,
                stream=True,
                keep_alive=keep_alive,
                options={CONF_NUM_CTX: num_ctx},
                think=think,
            )
        except (ollama.RequestError, ollama.ResponseError) as err:
            _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
            raise HomeAssistantError(
                f"Sorry, I had a problem talking to the Ollama server: {err}"
            ) from err

        delta_stream = chat_log.async_add_delta_content_stream(
            self.entity_id, _transform_stream(response_stream)
        )
        new_messages = [_convert_content(entry) async for entry in delta_stream]
        history.messages.extend(new_messages)

        if not chat_log.unresponded_tool_results:
            break
313-
314- def _trim_history (self , message_history : MessageHistory , max_messages : int ) -> None :
315- """Trims excess messages from a single history.
316-
317- This sets the max history to allow a configurable size history may take
318- up in the context window.
319-
320- Note that some messages in the history may not be from ollama only, and
321- may come from other anents, so the assumptions here may not strictly hold,
322- but generally should be effective.
323- """
324- if max_messages < 1 :
325- # Keep all messages
326- return
327-
328- # Ignore the in progress user message
329- num_previous_rounds = message_history .num_user_messages - 1
330- if num_previous_rounds >= max_messages :
331- # Trim history but keep system prompt (first message).
332- # Every other message should be an assistant message, so keep 2x
333- # message objects. Also keep the last in progress user message
334- num_keep = 2 * max_messages + 1
335- drop_index = len (message_history .messages ) - num_keep
336- message_history .messages = [
337- message_history .messages [0 ],
338- * message_history .messages [drop_index :],
339- ]
340-
341106 async def _async_entry_update_listener (
342107 self , hass : HomeAssistant , entry : ConfigEntry
343108 ) -> None :
0 commit comments