Skip to content

Commit f735331

Browse files
authored
Convert Ollama to subentries (home-assistant#147286)
* Convert Ollama to subentries * Add latest changes from Google subentries * Move config entry type to init
1 parent 5a20ef3 commit f735331

File tree

8 files changed

+625
-139
lines changed

8 files changed

+625
-139
lines changed

homeassistant/components/ollama/__init__.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,16 @@
88
import httpx
99
import ollama
1010

11-
from homeassistant.config_entries import ConfigEntry
11+
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
1212
from homeassistant.const import CONF_URL, Platform
1313
from homeassistant.core import HomeAssistant
1414
from homeassistant.exceptions import ConfigEntryNotReady
15-
from homeassistant.helpers import config_validation as cv
15+
from homeassistant.helpers import (
16+
config_validation as cv,
17+
device_registry as dr,
18+
entity_registry as er,
19+
)
20+
from homeassistant.helpers.typing import ConfigType
1621
from homeassistant.util.ssl import get_default_context
1722

1823
from .const import (
@@ -42,8 +47,16 @@
4247
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
4348
PLATFORMS = (Platform.CONVERSATION,)
4449

50+
type OllamaConfigEntry = ConfigEntry[ollama.AsyncClient]
51+
52+
53+
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
54+
"""Set up Ollama."""
55+
await async_migrate_integration(hass)
56+
return True
57+
4558

46-
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
59+
async def async_setup_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> bool:
4760
"""Set up Ollama from a config entry."""
4861
settings = {**entry.data, **entry.options}
4962
client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context())
@@ -53,8 +66,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
5366
except (TimeoutError, httpx.ConnectError) as err:
5467
raise ConfigEntryNotReady(err) from err
5568

56-
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = client
57-
69+
entry.runtime_data = client
5870
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
5971
return True
6072

@@ -63,5 +75,69 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
6375
"""Unload Ollama."""
6476
if not await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
6577
return False
66-
hass.data[DOMAIN].pop(entry.entry_id)
6778
return True
79+
80+
81+
async def async_migrate_integration(hass: HomeAssistant) -> None:
82+
"""Migrate integration entry structure."""
83+
84+
entries = hass.config_entries.async_entries(DOMAIN)
85+
if not any(entry.version == 1 for entry in entries):
86+
return
87+
88+
api_keys_entries: dict[str, ConfigEntry] = {}
89+
entity_registry = er.async_get(hass)
90+
device_registry = dr.async_get(hass)
91+
92+
for entry in entries:
93+
use_existing = False
94+
subentry = ConfigSubentry(
95+
data=entry.options,
96+
subentry_type="conversation",
97+
title=entry.title,
98+
unique_id=None,
99+
)
100+
if entry.data[CONF_URL] not in api_keys_entries:
101+
use_existing = True
102+
api_keys_entries[entry.data[CONF_URL]] = entry
103+
104+
parent_entry = api_keys_entries[entry.data[CONF_URL]]
105+
106+
hass.config_entries.async_add_subentry(parent_entry, subentry)
107+
conversation_entity = entity_registry.async_get_entity_id(
108+
"conversation",
109+
DOMAIN,
110+
entry.entry_id,
111+
)
112+
if conversation_entity is not None:
113+
entity_registry.async_update_entity(
114+
conversation_entity,
115+
config_entry_id=parent_entry.entry_id,
116+
config_subentry_id=subentry.subentry_id,
117+
new_unique_id=subentry.subentry_id,
118+
)
119+
120+
device = device_registry.async_get_device(
121+
identifiers={(DOMAIN, entry.entry_id)}
122+
)
123+
if device is not None:
124+
device_registry.async_update_device(
125+
device.id,
126+
new_identifiers={(DOMAIN, subentry.subentry_id)},
127+
add_config_subentry_id=subentry.subentry_id,
128+
add_config_entry_id=parent_entry.entry_id,
129+
)
130+
if parent_entry.entry_id != entry.entry_id:
131+
device_registry.async_update_device(
132+
device.id,
133+
remove_config_entry_id=entry.entry_id,
134+
)
135+
136+
if not use_existing:
137+
await hass.config_entries.async_remove(entry.entry_id)
138+
else:
139+
hass.config_entries.async_update_entry(
140+
entry,
141+
options={},
142+
version=2,
143+
)

homeassistant/components/ollama/config_flow.py

Lines changed: 133 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414

1515
from homeassistant.config_entries import (
1616
ConfigEntry,
17+
ConfigEntryState,
1718
ConfigFlow,
1819
ConfigFlowResult,
19-
OptionsFlow,
20+
ConfigSubentryFlow,
21+
SubentryFlowResult,
2022
)
21-
from homeassistant.const import CONF_LLM_HASS_API, CONF_URL
22-
from homeassistant.core import HomeAssistant
23+
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
24+
from homeassistant.core import HomeAssistant, callback
2325
from homeassistant.helpers import llm
2426
from homeassistant.helpers.selector import (
2527
BooleanSelector,
@@ -70,7 +72,7 @@
7072
class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
7173
"""Handle a config flow for Ollama."""
7274

73-
VERSION = 1
75+
VERSION = 2
7476

7577
def __init__(self) -> None:
7678
"""Initialize config flow."""
@@ -94,6 +96,8 @@ async def async_step_user(
9496

9597
errors = {}
9698

99+
self._async_abort_entries_match({CONF_URL: self.url})
100+
97101
try:
98102
self.client = ollama.AsyncClient(
99103
host=self.url, verify=get_default_context()
@@ -146,8 +150,16 @@ async def async_step_user(
146150
return await self.async_step_download()
147151

148152
return self.async_create_entry(
149-
title=_get_title(self.model),
153+
title=self.url,
150154
data={CONF_URL: self.url, CONF_MODEL: self.model},
155+
subentries=[
156+
{
157+
"subentry_type": "conversation",
158+
"data": {},
159+
"title": _get_title(self.model),
160+
"unique_id": None,
161+
}
162+
],
151163
)
152164

153165
async def async_step_download(
@@ -189,6 +201,14 @@ async def async_step_finish(
189201
return self.async_create_entry(
190202
title=_get_title(self.model),
191203
data={CONF_URL: self.url, CONF_MODEL: self.model},
204+
subentries=[
205+
{
206+
"subentry_type": "conversation",
207+
"data": {},
208+
"title": _get_title(self.model),
209+
"unique_id": None,
210+
}
211+
],
192212
)
193213

194214
async def async_step_failed(
@@ -197,41 +217,62 @@ async def async_step_failed(
197217
"""Step after model downloading has failed."""
198218
return self.async_abort(reason="download_failed")
199219

200-
@staticmethod
201-
def async_get_options_flow(
202-
config_entry: ConfigEntry,
203-
) -> OptionsFlow:
204-
"""Create the options flow."""
205-
return OllamaOptionsFlow(config_entry)
220+
@classmethod
221+
@callback
222+
def async_get_supported_subentry_types(
223+
cls, config_entry: ConfigEntry
224+
) -> dict[str, type[ConfigSubentryFlow]]:
225+
"""Return subentries supported by this integration."""
226+
return {"conversation": ConversationSubentryFlowHandler}
206227

207228

208-
class OllamaOptionsFlow(OptionsFlow):
209-
"""Ollama options flow."""
229+
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
230+
"""Flow for managing conversation subentries."""
210231

211-
def __init__(self, config_entry: ConfigEntry) -> None:
212-
"""Initialize options flow."""
213-
self.url: str = config_entry.data[CONF_URL]
214-
self.model: str = config_entry.data[CONF_MODEL]
232+
@property
233+
def _is_new(self) -> bool:
234+
"""Return if this is a new subentry."""
235+
return self.source == "user"
215236

216-
async def async_step_init(
237+
async def async_step_set_options(
217238
self, user_input: dict[str, Any] | None = None
218-
) -> ConfigFlowResult:
219-
"""Manage the options."""
220-
if user_input is not None:
239+
) -> SubentryFlowResult:
240+
"""Set conversation options."""
241+
# abort if entry is not loaded
242+
if self._get_entry().state != ConfigEntryState.LOADED:
243+
return self.async_abort(reason="entry_not_loaded")
244+
245+
errors: dict[str, str] = {}
246+
247+
if user_input is None:
248+
if self._is_new:
249+
options = {}
250+
else:
251+
options = self._get_reconfigure_subentry().data.copy()
252+
253+
elif self._is_new:
221254
return self.async_create_entry(
222-
title=_get_title(self.model), data=user_input
255+
title=user_input.pop(CONF_NAME),
256+
data=user_input,
257+
)
258+
else:
259+
return self.async_update_and_abort(
260+
self._get_entry(),
261+
self._get_reconfigure_subentry(),
262+
data=user_input,
223263
)
224264

225-
options: Mapping[str, Any] = self.config_entry.options or {}
226-
schema = ollama_config_option_schema(self.hass, options)
265+
schema = ollama_config_option_schema(self.hass, self._is_new, options)
227266
return self.async_show_form(
228-
step_id="init",
229-
data_schema=vol.Schema(schema),
267+
step_id="set_options", data_schema=vol.Schema(schema), errors=errors
230268
)
231269

270+
async_step_user = async_step_set_options
271+
async_step_reconfigure = async_step_set_options
272+
232273

233274
def ollama_config_option_schema(
234-
hass: HomeAssistant, options: Mapping[str, Any]
275+
hass: HomeAssistant, is_new: bool, options: Mapping[str, Any]
235276
) -> dict:
236277
"""Ollama options schema."""
237278
hass_apis: list[SelectOptionDict] = [
@@ -242,54 +283,72 @@ def ollama_config_option_schema(
242283
for api in llm.async_get_apis(hass)
243284
]
244285

245-
return {
246-
vol.Optional(
247-
CONF_PROMPT,
248-
description={
249-
"suggested_value": options.get(
250-
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
286+
if is_new:
287+
schema: dict[vol.Required | vol.Optional, Any] = {
288+
vol.Required(CONF_NAME, default="Ollama Conversation"): str,
289+
}
290+
else:
291+
schema = {}
292+
293+
schema.update(
294+
{
295+
vol.Optional(
296+
CONF_PROMPT,
297+
description={
298+
"suggested_value": options.get(
299+
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
300+
)
301+
},
302+
): TemplateSelector(),
303+
vol.Optional(
304+
CONF_LLM_HASS_API,
305+
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
306+
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
307+
vol.Optional(
308+
CONF_NUM_CTX,
309+
description={
310+
"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)
311+
},
312+
): NumberSelector(
313+
NumberSelectorConfig(
314+
min=MIN_NUM_CTX,
315+
max=MAX_NUM_CTX,
316+
step=1,
317+
mode=NumberSelectorMode.BOX,
251318
)
252-
},
253-
): TemplateSelector(),
254-
vol.Optional(
255-
CONF_LLM_HASS_API,
256-
description={"suggested_value": options.get(CONF_LLM_HASS_API)},
257-
): SelectSelector(SelectSelectorConfig(options=hass_apis, multiple=True)),
258-
vol.Optional(
259-
CONF_NUM_CTX,
260-
description={"suggested_value": options.get(CONF_NUM_CTX, DEFAULT_NUM_CTX)},
261-
): NumberSelector(
262-
NumberSelectorConfig(
263-
min=MIN_NUM_CTX, max=MAX_NUM_CTX, step=1, mode=NumberSelectorMode.BOX
264-
)
265-
),
266-
vol.Optional(
267-
CONF_MAX_HISTORY,
268-
description={
269-
"suggested_value": options.get(CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY)
270-
},
271-
): NumberSelector(
272-
NumberSelectorConfig(
273-
min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
274-
)
275-
),
276-
vol.Optional(
277-
CONF_KEEP_ALIVE,
278-
description={
279-
"suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)
280-
},
281-
): NumberSelector(
282-
NumberSelectorConfig(
283-
min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
284-
)
285-
),
286-
vol.Optional(
287-
CONF_THINK,
288-
description={
289-
"suggested_value": options.get("think", DEFAULT_THINK),
290-
},
291-
): BooleanSelector(),
292-
}
319+
),
320+
vol.Optional(
321+
CONF_MAX_HISTORY,
322+
description={
323+
"suggested_value": options.get(
324+
CONF_MAX_HISTORY, DEFAULT_MAX_HISTORY
325+
)
326+
},
327+
): NumberSelector(
328+
NumberSelectorConfig(
329+
min=0, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
330+
)
331+
),
332+
vol.Optional(
333+
CONF_KEEP_ALIVE,
334+
description={
335+
"suggested_value": options.get(CONF_KEEP_ALIVE, DEFAULT_KEEP_ALIVE)
336+
},
337+
): NumberSelector(
338+
NumberSelectorConfig(
339+
min=-1, max=sys.maxsize, step=1, mode=NumberSelectorMode.BOX
340+
)
341+
),
342+
vol.Optional(
343+
CONF_THINK,
344+
description={
345+
"suggested_value": options.get("think", DEFAULT_THINK),
346+
},
347+
): BooleanSelector(),
348+
}
349+
)
350+
351+
return schema
293352

294353

295354
def _get_title(model: str) -> str:

0 commit comments

Comments
 (0)