Skip to content

Commit 8eb906f

Browse files
balloobjoostlek
andauthored
Migrate OpenAI to config subentries (home-assistant#147282)
* Migrate OpenAI to config subentries * Add latest changes from Google subentries * Update homeassistant/components/openai_conversation/__init__.py Co-authored-by: Joost Lekkerkerker <[email protected]> --------- Co-authored-by: Joost Lekkerkerker <[email protected]>
1 parent 4d98431 commit 8eb906f

File tree

10 files changed

+734
-167
lines changed

10 files changed

+734
-167
lines changed

homeassistant/components/openai_conversation/__init__.py

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
import voluptuous as vol
2121

22-
from homeassistant.config_entries import ConfigEntry
22+
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
2323
from homeassistant.const import CONF_API_KEY, Platform
2424
from homeassistant.core import (
2525
HomeAssistant,
@@ -32,7 +32,12 @@
3232
HomeAssistantError,
3333
ServiceValidationError,
3434
)
35-
from homeassistant.helpers import config_validation as cv, selector
35+
from homeassistant.helpers import (
36+
config_validation as cv,
37+
device_registry as dr,
38+
entity_registry as er,
39+
selector,
40+
)
3641
from homeassistant.helpers.httpx_client import get_async_client
3742
from homeassistant.helpers.typing import ConfigType
3843

@@ -73,6 +78,7 @@ def encode_file(file_path: str) -> tuple[str, str]:
7378

7479
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
7580
"""Set up OpenAI Conversation."""
81+
await async_migrate_integration(hass)
7682

7783
async def render_image(call: ServiceCall) -> ServiceResponse:
7884
"""Render an image with dall-e."""
@@ -118,7 +124,21 @@ async def send_prompt(call: ServiceCall) -> ServiceResponse:
118124
translation_placeholders={"config_entry": entry_id},
119125
)
120126

121-
model: str = entry.options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL)
127+
# Get first conversation subentry for options
128+
conversation_subentry = next(
129+
(
130+
sub
131+
for sub in entry.subentries.values()
132+
if sub.subentry_type == "conversation"
133+
),
134+
None,
135+
)
136+
if not conversation_subentry:
137+
raise ServiceValidationError("No conversation configuration found")
138+
139+
model: str = conversation_subentry.data.get(
140+
CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL
141+
)
122142
client: openai.AsyncClient = entry.runtime_data
123143

124144
content: ResponseInputMessageContentListParam = [
@@ -169,11 +189,11 @@ def append_files_to_content() -> None:
169189
model_args = {
170190
"model": model,
171191
"input": messages,
172-
"max_output_tokens": entry.options.get(
192+
"max_output_tokens": conversation_subentry.data.get(
173193
CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS
174194
),
175-
"top_p": entry.options.get(CONF_TOP_P, RECOMMENDED_TOP_P),
176-
"temperature": entry.options.get(
195+
"top_p": conversation_subentry.data.get(CONF_TOP_P, RECOMMENDED_TOP_P),
196+
"temperature": conversation_subentry.data.get(
177197
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE
178198
),
179199
"user": call.context.user_id,
@@ -182,7 +202,7 @@ def append_files_to_content() -> None:
182202

183203
if model.startswith("o"):
184204
model_args["reasoning"] = {
185-
"effort": entry.options.get(
205+
"effort": conversation_subentry.data.get(
186206
CONF_REASONING_EFFORT, RECOMMENDED_REASONING_EFFORT
187207
)
188208
}
@@ -269,3 +289,68 @@ async def async_setup_entry(hass: HomeAssistant, entry: OpenAIConfigEntry) -> bo
269289
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
270290
"""Unload OpenAI."""
271291
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
292+
293+
294+
async def async_migrate_integration(hass: HomeAssistant) -> None:
295+
"""Migrate integration entry structure."""
296+
297+
entries = hass.config_entries.async_entries(DOMAIN)
298+
if not any(entry.version == 1 for entry in entries):
299+
return
300+
301+
api_keys_entries: dict[str, ConfigEntry] = {}
302+
entity_registry = er.async_get(hass)
303+
device_registry = dr.async_get(hass)
304+
305+
for entry in entries:
306+
use_existing = False
307+
subentry = ConfigSubentry(
308+
data=entry.options,
309+
subentry_type="conversation",
310+
title=entry.title,
311+
unique_id=None,
312+
)
313+
if entry.data[CONF_API_KEY] not in api_keys_entries:
314+
use_existing = True
315+
api_keys_entries[entry.data[CONF_API_KEY]] = entry
316+
317+
parent_entry = api_keys_entries[entry.data[CONF_API_KEY]]
318+
319+
hass.config_entries.async_add_subentry(parent_entry, subentry)
320+
conversation_entity = entity_registry.async_get_entity_id(
321+
"conversation",
322+
DOMAIN,
323+
entry.entry_id,
324+
)
325+
if conversation_entity is not None:
326+
entity_registry.async_update_entity(
327+
conversation_entity,
328+
config_entry_id=parent_entry.entry_id,
329+
config_subentry_id=subentry.subentry_id,
330+
new_unique_id=subentry.subentry_id,
331+
)
332+
333+
device = device_registry.async_get_device(
334+
identifiers={(DOMAIN, entry.entry_id)}
335+
)
336+
if device is not None:
337+
device_registry.async_update_device(
338+
device.id,
339+
new_identifiers={(DOMAIN, subentry.subentry_id)},
340+
add_config_subentry_id=subentry.subentry_id,
341+
add_config_entry_id=parent_entry.entry_id,
342+
)
343+
if parent_entry.entry_id != entry.entry_id:
344+
device_registry.async_update_device(
345+
device.id,
346+
remove_config_entry_id=entry.entry_id,
347+
)
348+
349+
if not use_existing:
350+
await hass.config_entries.async_remove(entry.entry_id)
351+
else:
352+
hass.config_entries.async_update_entry(
353+
entry,
354+
options={},
355+
version=2,
356+
)

homeassistant/components/openai_conversation/config_flow.py

Lines changed: 108 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,20 @@
1313
from homeassistant.components.zone import ENTITY_ID_HOME
1414
from homeassistant.config_entries import (
1515
ConfigEntry,
16+
ConfigEntryState,
1617
ConfigFlow,
1718
ConfigFlowResult,
18-
OptionsFlow,
19+
ConfigSubentryFlow,
20+
SubentryFlowResult,
1921
)
2022
from homeassistant.const import (
2123
ATTR_LATITUDE,
2224
ATTR_LONGITUDE,
2325
CONF_API_KEY,
2426
CONF_LLM_HASS_API,
27+
CONF_NAME,
2528
)
26-
from homeassistant.core import HomeAssistant
29+
from homeassistant.core import HomeAssistant, callback
2730
from homeassistant.helpers import llm
2831
from homeassistant.helpers.httpx_client import get_async_client
2932
from homeassistant.helpers.selector import (
@@ -52,6 +55,7 @@
5255
CONF_WEB_SEARCH_REGION,
5356
CONF_WEB_SEARCH_TIMEZONE,
5457
CONF_WEB_SEARCH_USER_LOCATION,
58+
DEFAULT_CONVERSATION_NAME,
5559
DOMAIN,
5660
RECOMMENDED_CHAT_MODEL,
5761
RECOMMENDED_MAX_TOKENS,
@@ -94,7 +98,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> None:
9498
class OpenAIConfigFlow(ConfigFlow, domain=DOMAIN):
9599
"""Handle a config flow for OpenAI Conversation."""
96100

97-
VERSION = 1
101+
VERSION = 2
98102

99103
async def async_step_user(
100104
self, user_input: dict[str, Any] | None = None
@@ -107,6 +111,7 @@ async def async_step_user(
107111

108112
errors: dict[str, str] = {}
109113

114+
self._async_abort_entries_match(user_input)
110115
try:
111116
await validate_input(self.hass, user_input)
112117
except openai.APIConnectionError:
@@ -120,32 +125,61 @@ async def async_step_user(
120125
return self.async_create_entry(
121126
title="ChatGPT",
122127
data=user_input,
123-
options=RECOMMENDED_OPTIONS,
128+
subentries=[
129+
{
130+
"subentry_type": "conversation",
131+
"data": RECOMMENDED_OPTIONS,
132+
"title": DEFAULT_CONVERSATION_NAME,
133+
"unique_id": None,
134+
}
135+
],
124136
)
125137

126138
return self.async_show_form(
127139
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
128140
)
129141

130-
@staticmethod
131-
def async_get_options_flow(
132-
config_entry: ConfigEntry,
133-
) -> OptionsFlow:
134-
"""Create the options flow."""
135-
return OpenAIOptionsFlow(config_entry)
142+
@classmethod
143+
@callback
144+
def async_get_supported_subentry_types(
145+
cls, config_entry: ConfigEntry
146+
) -> dict[str, type[ConfigSubentryFlow]]:
147+
"""Return subentries supported by this integration."""
148+
return {"conversation": ConversationSubentryFlowHandler}
149+
136150

151+
class ConversationSubentryFlowHandler(ConfigSubentryFlow):
152+
"""Flow for managing conversation subentries."""
137153

138-
class OpenAIOptionsFlow(OptionsFlow):
139-
"""OpenAI config flow options handler."""
154+
last_rendered_recommended = False
155+
options: dict[str, Any]
156+
157+
@property
158+
def _is_new(self) -> bool:
159+
"""Return if this is a new subentry."""
160+
return self.source == "user"
161+
162+
async def async_step_user(
163+
self, user_input: dict[str, Any] | None = None
164+
) -> SubentryFlowResult:
165+
"""Add a subentry."""
166+
self.options = RECOMMENDED_OPTIONS.copy()
167+
return await self.async_step_init()
140168

141-
def __init__(self, config_entry: ConfigEntry) -> None:
142-
"""Initialize options flow."""
143-
self.options = config_entry.options.copy()
169+
async def async_step_reconfigure(
170+
self, user_input: dict[str, Any] | None = None
171+
) -> SubentryFlowResult:
172+
"""Handle reconfiguration of a subentry."""
173+
self.options = self._get_reconfigure_subentry().data.copy()
174+
return await self.async_step_init()
144175

145176
async def async_step_init(
146177
self, user_input: dict[str, Any] | None = None
147-
) -> ConfigFlowResult:
178+
) -> SubentryFlowResult:
148179
"""Manage initial options."""
180+
# abort if entry is not loaded
181+
if self._get_entry().state != ConfigEntryState.LOADED:
182+
return self.async_abort(reason="entry_not_loaded")
149183
options = self.options
150184

151185
hass_apis: list[SelectOptionDict] = [
@@ -160,25 +194,47 @@ async def async_step_init(
160194
):
161195
options[CONF_LLM_HASS_API] = [suggested_llm_apis]
162196

163-
step_schema: VolDictType = {
164-
vol.Optional(
165-
CONF_PROMPT,
166-
description={"suggested_value": llm.DEFAULT_INSTRUCTIONS_PROMPT},
167-
): TemplateSelector(),
168-
vol.Optional(CONF_LLM_HASS_API): SelectSelector(
169-
SelectSelectorConfig(options=hass_apis, multiple=True)
170-
),
171-
vol.Required(
172-
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
173-
): bool,
174-
}
197+
step_schema: VolDictType = {}
198+
199+
if self._is_new:
200+
step_schema[vol.Required(CONF_NAME, default=DEFAULT_CONVERSATION_NAME)] = (
201+
str
202+
)
203+
204+
step_schema.update(
205+
{
206+
vol.Optional(
207+
CONF_PROMPT,
208+
description={
209+
"suggested_value": options.get(
210+
CONF_PROMPT, llm.DEFAULT_INSTRUCTIONS_PROMPT
211+
)
212+
},
213+
): TemplateSelector(),
214+
vol.Optional(CONF_LLM_HASS_API): SelectSelector(
215+
SelectSelectorConfig(options=hass_apis, multiple=True)
216+
),
217+
vol.Required(
218+
CONF_RECOMMENDED, default=options.get(CONF_RECOMMENDED, False)
219+
): bool,
220+
}
221+
)
175222

176223
if user_input is not None:
177224
if not user_input.get(CONF_LLM_HASS_API):
178225
user_input.pop(CONF_LLM_HASS_API, None)
179226

180227
if user_input[CONF_RECOMMENDED]:
181-
return self.async_create_entry(title="", data=user_input)
228+
if self._is_new:
229+
return self.async_create_entry(
230+
title=user_input.pop(CONF_NAME),
231+
data=user_input,
232+
)
233+
return self.async_update_and_abort(
234+
self._get_entry(),
235+
self._get_reconfigure_subentry(),
236+
data=user_input,
237+
)
182238

183239
options.update(user_input)
184240
if CONF_LLM_HASS_API in options and CONF_LLM_HASS_API not in user_input:
@@ -194,7 +250,7 @@ async def async_step_init(
194250

195251
async def async_step_advanced(
196252
self, user_input: dict[str, Any] | None = None
197-
) -> ConfigFlowResult:
253+
) -> SubentryFlowResult:
198254
"""Manage advanced options."""
199255
options = self.options
200256
errors: dict[str, str] = {}
@@ -236,7 +292,7 @@ async def async_step_advanced(
236292

237293
async def async_step_model(
238294
self, user_input: dict[str, Any] | None = None
239-
) -> ConfigFlowResult:
295+
) -> SubentryFlowResult:
240296
"""Manage model-specific options."""
241297
options = self.options
242298
errors: dict[str, str] = {}
@@ -303,7 +359,16 @@ async def async_step_model(
303359
}
304360

305361
if not step_schema:
306-
return self.async_create_entry(title="", data=options)
362+
if self._is_new:
363+
return self.async_create_entry(
364+
title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME),
365+
data=options,
366+
)
367+
return self.async_update_and_abort(
368+
self._get_entry(),
369+
self._get_reconfigure_subentry(),
370+
data=options,
371+
)
307372

308373
if user_input is not None:
309374
if user_input.get(CONF_WEB_SEARCH):
@@ -316,7 +381,16 @@ async def async_step_model(
316381
options.pop(CONF_WEB_SEARCH_TIMEZONE, None)
317382

318383
options.update(user_input)
319-
return self.async_create_entry(title="", data=options)
384+
if self._is_new:
385+
return self.async_create_entry(
386+
title=options.pop(CONF_NAME, DEFAULT_CONVERSATION_NAME),
387+
data=options,
388+
)
389+
return self.async_update_and_abort(
390+
self._get_entry(),
391+
self._get_reconfigure_subentry(),
392+
data=options,
393+
)
320394

321395
return self.async_show_form(
322396
step_id="model",
@@ -332,7 +406,7 @@ async def _get_location_data(self) -> dict[str, str]:
332406
zone_home = self.hass.states.get(ENTITY_ID_HOME)
333407
if zone_home is not None:
334408
client = openai.AsyncOpenAI(
335-
api_key=self.config_entry.data[CONF_API_KEY],
409+
api_key=self._get_entry().data[CONF_API_KEY],
336410
http_client=get_async_client(self.hass),
337411
)
338412
location_schema = vol.Schema(

0 commit comments

Comments
 (0)