Commit 4b5c04b

allenporter and balloob authored
Add AI Task support in Ollama (home-assistant#148226)
Co-authored-by: Paulus Schoutsen <[email protected]>
1 parent 8cb9cad commit 4b5c04b

File tree

11 files changed: +635 -45 lines changed

homeassistant/components/ollama/__init__.py

Lines changed: 28 additions & 1 deletion

@@ -28,6 +28,7 @@
     CONF_NUM_CTX,
     CONF_PROMPT,
     CONF_THINK,
+    DEFAULT_AI_TASK_NAME,
     DEFAULT_NAME,
     DEFAULT_TIMEOUT,
     DOMAIN,
@@ -47,7 +48,7 @@
 ]
 
 CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
-PLATFORMS = (Platform.CONVERSATION,)
+PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION)
 
 type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient]
 
@@ -118,6 +119,7 @@ async def async_migrate_integration(hass: HomeAssistant) -> None:
             parent_entry = api_keys_entries[entry.data[CONF_URL]]
 
             hass.config_entries.async_add_subentry(parent_entry, subentry)
+
             conversation_entity = entity_registry.async_get_entity_id(
                 "conversation",
                 DOMAIN,
@@ -208,6 +210,31 @@ async def async_migrate_entry(hass: HomeAssistant, entry: OllamaConfigEntry) ->
             minor_version=1,
         )
 
+    if entry.version == 3 and entry.minor_version == 1:
+        # Add AI Task subentry with default options. We can only create a new
+        # subentry if we can find an existing model in the entry. The model
+        # was removed in the previous migration step, so we need to
+        # check the subentries for an existing model.
+        existing_model = next(
+            iter(
+                model
+                for subentry in entry.subentries.values()
+                if (model := subentry.data.get(CONF_MODEL)) is not None
+            ),
+            None,
+        )
+        if existing_model:
+            hass.config_entries.async_add_subentry(
+                entry,
+                ConfigSubentry(
+                    data=MappingProxyType({CONF_MODEL: existing_model}),
+                    subentry_type="ai_task_data",
+                    title=DEFAULT_AI_TASK_NAME,
+                    unique_id=None,
+                ),
+            )
+        hass.config_entries.async_update_entry(entry, minor_version=2)
+
     _LOGGER.debug(
         "Migration to version %s:%s successful", entry.version, entry.minor_version
     )
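
Note on the migration hunk above: the model lookup scans all subentries and keeps the first non-None model. A minimal standalone sketch of that pattern, with stand-in dicts instead of real ConfigSubentry objects:

# Stand-in subentry option dicts; only one carries a model.
subentry_options = [{"num_ctx": 8192}, {"model": "llama3.2:latest"}, {}]

# Same next(..., None) + walrus pattern as the migration step:
# yield each non-None "model" value, fall back to None if none exists.
existing_model = next(
    (
        model
        for options in subentry_options
        if (model := options.get("model")) is not None
    ),
    None,
)
print(existing_model)  # -> llama3.2:latest
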
homeassistant/components/ollama/ai_task.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+"""AI Task integration for Ollama."""
+
+from __future__ import annotations
+
+from json import JSONDecodeError
+import logging
+
+from homeassistant.components import ai_task, conversation
+from homeassistant.config_entries import ConfigEntry
+from homeassistant.core import HomeAssistant
+from homeassistant.exceptions import HomeAssistantError
+from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
+from homeassistant.util.json import json_loads
+
+from .entity import OllamaBaseLLMEntity
+
+_LOGGER = logging.getLogger(__name__)
+
+
+async def async_setup_entry(
+    hass: HomeAssistant,
+    config_entry: ConfigEntry,
+    async_add_entities: AddConfigEntryEntitiesCallback,
+) -> None:
+    """Set up AI Task entities."""
+    for subentry in config_entry.subentries.values():
+        if subentry.subentry_type != "ai_task_data":
+            continue
+
+        async_add_entities(
+            [OllamaTaskEntity(config_entry, subentry)],
+            config_subentry_id=subentry.subentry_id,
+        )
+
+
+class OllamaTaskEntity(
+    ai_task.AITaskEntity,
+    OllamaBaseLLMEntity,
+):
+    """Ollama AI Task entity."""
+
+    _attr_supported_features = ai_task.AITaskEntityFeature.GENERATE_DATA
+
+    async def _async_generate_data(
+        self,
+        task: ai_task.GenDataTask,
+        chat_log: conversation.ChatLog,
+    ) -> ai_task.GenDataTaskResult:
+        """Handle a generate data task."""
+        await self._async_handle_chat_log(chat_log, task.structure)
+
+        if not isinstance(chat_log.content[-1], conversation.AssistantContent):
+            raise HomeAssistantError(
+                "Last content in chat log is not an AssistantContent"
+            )
+
+        text = chat_log.content[-1].content or ""
+
+        if not task.structure:
+            return ai_task.GenDataTaskResult(
+                conversation_id=chat_log.conversation_id,
+                data=text,
+            )
+        try:
+            data = json_loads(text)
+        except JSONDecodeError as err:
+            _LOGGER.error(
+                "Failed to parse JSON response: %s. Response: %s",
+                err,
+                text,
+            )
+            raise HomeAssistantError("Error with Ollama structured response") from err
+
+        return ai_task.GenDataTaskResult(
+            conversation_id=chat_log.conversation_id,
+            data=data,
+        )
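
Once an ai_task_data subentry exists, the entity is exercised through the ai_task integration's generate_data action. A hedged, test-style sketch of such a call; the entity_id and exact action fields are assumptions, not part of this commit:

# Assumed entity_id and action fields; adjust to your installation.
result = await hass.services.async_call(
    "ai_task",
    "generate_data",
    {
        "entity_id": "ai_task.ollama_ai_task",
        "task_name": "grocery_list",
        "instructions": "Suggest three grocery items.",
    },
    blocking=True,
    return_response=True,
)
# Without a structure, _async_generate_data returns the raw text;
# with one, it parses the model output via json_loads first.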

homeassistant/components/ollama/config_flow.py

Lines changed: 52 additions & 28 deletions

@@ -46,6 +46,8 @@
     CONF_NUM_CTX,
     CONF_PROMPT,
     CONF_THINK,
+    DEFAULT_AI_TASK_NAME,
+    DEFAULT_CONVERSATION_NAME,
     DEFAULT_KEEP_ALIVE,
     DEFAULT_MAX_HISTORY,
     DEFAULT_MODEL,
@@ -74,7 +76,7 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
     """Handle a config flow for Ollama."""
 
     VERSION = 3
-    MINOR_VERSION = 1
+    MINOR_VERSION = 2
 
     def __init__(self) -> None:
         """Initialize config flow."""
@@ -136,11 +138,14 @@ def async_get_supported_subentry_types(
         cls, config_entry: ConfigEntry
     ) -> dict[str, type[ConfigSubentryFlow]]:
         """Return subentries supported by this integration."""
-        return {"conversation": ConversationSubentryFlowHandler}
+        return {
+            "conversation": OllamaSubentryFlowHandler,
+            "ai_task_data": OllamaSubentryFlowHandler,
+        }
 
 
-class ConversationSubentryFlowHandler(ConfigSubentryFlow):
-    """Flow for managing conversation subentries."""
+class OllamaSubentryFlowHandler(ConfigSubentryFlow):
+    """Flow for managing Ollama subentries."""
 
     def __init__(self) -> None:
         """Initialize the subentry flow."""
@@ -201,7 +206,11 @@ async def async_step_set_options(
             step_id="set_options",
             data_schema=vol.Schema(
                 ollama_config_option_schema(
-                    self.hass, self._is_new, options, models_to_list
+                    self.hass,
+                    self._is_new,
+                    self._subentry_type,
+                    options,
+                    models_to_list,
                 )
             ),
         )
@@ -300,13 +309,19 @@ async def async_step_finish(
 def ollama_config_option_schema(
     hass: HomeAssistant,
     is_new: bool,
+    subentry_type: str,
     options: Mapping[str, Any],
     models_to_list: list[SelectOptionDict],
 ) -> dict:
     """Ollama options schema."""
     if is_new:
+        if subentry_type == "ai_task_data":
+            default_name = DEFAULT_AI_TASK_NAME
+        else:
+            default_name = DEFAULT_CONVERSATION_NAME
+
         schema: dict = {
-            vol.Required(CONF_NAME, default="Ollama Conversation"): str,
+            vol.Required(CONF_NAME, default=default_name): str,
         }
     else:
         schema = {}
@@ -319,29 +334,38 @@ def ollama_config_option_schema(
             ): SelectSelector(
                 SelectSelectorConfig(options=models_to_list, custom_value=True)
             ),
-            vol.Optional(
-                CONF_PROMPT,
-                description={
-                    "suggested_value": options.get(
-                        CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
-                    )
-                },
-            ): TemplateSelector(),
-            vol.Optional(
-                CONF_LLM_HASS_API,
-                description={"suggested_value": options.get(CONF_LLM_HASS_API)},
-            ): SelectSelector(
-                SelectSelectorConfig(
-                    options=[
-                        SelectOptionDict(
-                            label=api.name,
-                            value=api.id,
-                        )
-                        for api in llm.async_get_apis(hass)
-                    ],
-                    multiple=True,
-                )
-            ),
+        }
+    )
+    if subentry_type == "conversation":
+        schema.update(
+            {
+                vol.Optional(
+                    CONF_PROMPT,
+                    description={
+                        "suggested_value": options.get(
+                            CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
+                        )
+                    },
+                ): TemplateSelector(),
+                vol.Optional(
+                    CONF_LLM_HASS_API,
+                    description={"suggested_value": options.get(CONF_LLM_HASS_API)},
+                ): SelectSelector(
+                    SelectSelectorConfig(
+                        options=[
+                            SelectOptionDict(
+                                label=api.name,
+                                value=api.id,
+                            )
+                            for api in llm.async_get_apis(hass)
+                        ],
+                        multiple=True,
+                    )
+                ),
+            }
+        )
+    schema.update(
+        {
             vol.Optional(
                 CONF_NUM_CTX,
                 description={
homeassistant/components/ollama/const.py

Lines changed: 7 additions & 0 deletions

@@ -159,3 +159,10 @@
     "zephyr",
 ]
 DEFAULT_MODEL = "llama3.2:latest"
+
+DEFAULT_CONVERSATION_NAME = "Ollama Conversation"
+DEFAULT_AI_TASK_NAME = "Ollama AI Task"
+
+RECOMMENDED_CONVERSATION_OPTIONS = {
+    CONF_MAX_HISTORY: DEFAULT_MAX_HISTORY,
+}

homeassistant/components/ollama/entity.py

Lines changed: 14 additions & 0 deletions

@@ -8,6 +8,7 @@
 from typing import Any
 
 import ollama
+import voluptuous as vol
 from voluptuous_openapi import convert
 
 from homeassistant.components import conversation
@@ -180,6 +181,7 @@ def __init__(self, entry: OllamaConfigEntry, subentry: ConfigSubentry) -> None:
     async def _async_handle_chat_log(
         self,
         chat_log: conversation.ChatLog,
+        structure: vol.Schema | None = None,
     ) -> None:
         """Generate an answer for the chat log."""
         settings = {**self.entry.data, **self.subentry.data}
@@ -200,6 +202,17 @@
         max_messages = int(settings.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY))
         self._trim_history(message_history, max_messages)
 
+        output_format: dict[str, Any] | None = None
+        if structure:
+            output_format = convert(
+                structure,
+                custom_serializer=(
+                    chat_log.llm_api.custom_serializer
+                    if chat_log.llm_api
+                    else llm.selector_serializer
+                ),
+            )
+
         # Get response
         # To prevent infinite loops, we limit the number of iterations
         for _iteration in range(MAX_TOOL_ITERATIONS):
@@ -214,6 +227,7 @@
                 keep_alive=f"{settings.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)}s",
                 options={CONF_NUM_CTX: settings.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
                 think=settings.get(CONF_THINK),
+                format=output_format,
             )
         except (ollama.RequestError, ollama.ResponseError) as err:
             _LOGGER.error("Unexpected error talking to Ollama server: %s", err)
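
Here, structure is a voluptuous schema that convert() turns into a JSON-schema dict, which the Ollama client accepts as its format argument to constrain the reply. A hedged end-to-end sketch against a local server; the host, model, and toy schema are assumptions:

import asyncio

import voluptuous as vol
from ollama import AsyncClient
from voluptuous_openapi import convert


async def main() -> None:
    # Convert a voluptuous schema to JSON schema, as the entity does.
    structure = vol.Schema({vol.Required("colors"): [str]})
    output_format = convert(structure)

    client = AsyncClient(host="http://localhost:11434")  # assumed local Ollama
    response = await client.chat(
        model="llama3.2:latest",  # assumed model
        messages=[{"role": "user", "content": "Name two primary colors."}],
        format=output_format,  # constrains the reply to the schema
    )
    print(response.message.content)  # JSON text matching the schema


asyncio.run(main())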

homeassistant/components/ollama/strings.json

Lines changed: 38 additions & 0 deletions

@@ -55,6 +55,44 @@
       "progress": {
         "download": "Please wait while the model is downloaded, which may take a very long time. Check your Ollama server logs for more details."
       }
+    },
+    "ai_task_data": {
+      "initiate_flow": {
+        "user": "Add Generate data with AI service",
+        "reconfigure": "Reconfigure Generate data with AI service"
+      },
+      "entry_type": "Generate data with AI service",
+      "step": {
+        "set_options": {
+          "data": {
+            "model": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::model%]",
+            "name": "[%key:common::config_flow::data::name%]",
+            "prompt": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::prompt%]",
+            "max_history": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::max_history%]",
+            "num_ctx": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::num_ctx%]",
+            "keep_alive": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::keep_alive%]",
+            "think": "[%key:component::ollama::config_subentries::conversation::step::set_options::data::think%]"
+          },
+          "data_description": {
+            "prompt": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::prompt%]",
+            "keep_alive": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::keep_alive%]",
+            "num_ctx": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::num_ctx%]",
+            "think": "[%key:component::ollama::config_subentries::conversation::step::set_options::data_description::think%]"
+          }
+        },
+        "download": {
+          "title": "[%key:component::ollama::config_subentries::conversation::step::download::title%]"
+        }
+      },
+      "abort": {
+        "reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
+        "entry_not_loaded": "[%key:component::ollama::config_subentries::conversation::abort::entry_not_loaded%]",
+        "download_failed": "[%key:component::ollama::config_subentries::conversation::abort::download_failed%]",
+        "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]"
+      },
+      "progress": {
+        "download": "[%key:component::ollama::config_subentries::conversation::progress::download%]"
+      }
     }
   }
 }

tests/components/ollama/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -12,3 +12,8 @@
     ollama.CONF_MAX_HISTORY: 2,
     ollama.CONF_MODEL: "test_model:latest",
 }
+
+TEST_AI_TASK_OPTIONS = {
+    ollama.CONF_MAX_HISTORY: 2,
+    ollama.CONF_MODEL: "test_model:latest",
+}
