Skip to content

Commit 1428b41

Browse files
Improve sql config flow (home-assistant#150757)
Co-authored-by: Joost Lekkerkerker <[email protected]>
1 parent e65b429 commit 1428b41

File tree

10 files changed

+966
-756
lines changed

10 files changed

+966
-756
lines changed

homeassistant/components/sql/__init__.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6+
from typing import Any
67

78
import sqlparse
89
import voluptuous as vol
@@ -32,7 +33,13 @@
3233
)
3334
from homeassistant.helpers.typing import ConfigType
3435

35-
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN, PLATFORMS
36+
from .const import (
37+
CONF_ADVANCED_OPTIONS,
38+
CONF_COLUMN_NAME,
39+
CONF_QUERY,
40+
DOMAIN,
41+
PLATFORMS,
42+
)
3643
from .util import redact_credentials
3744

3845
_LOGGER = logging.getLogger(__name__)
@@ -75,18 +82,6 @@ def validate_sql_select(value: str) -> str:
7582
)
7683

7784

78-
def remove_configured_db_url_if_not_needed(
79-
hass: HomeAssistant, entry: ConfigEntry
80-
) -> None:
81-
"""Remove db url from config if it matches recorder database."""
82-
hass.config_entries.async_update_entry(
83-
entry,
84-
options={
85-
key: value for key, value in entry.options.items() if key != CONF_DB_URL
86-
},
87-
)
88-
89-
9085
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
9186
"""Set up SQL from yaml config."""
9287
if (conf := config.get(DOMAIN)) is None:
@@ -107,8 +102,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
107102
redact_credentials(entry.options.get(CONF_DB_URL)),
108103
redact_credentials(get_instance(hass).db_url),
109104
)
110-
if entry.options.get(CONF_DB_URL) == get_instance(hass).db_url:
111-
remove_configured_db_url_if_not_needed(hass, entry)
112105

113106
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
114107

@@ -119,3 +112,47 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
119112
"""Unload SQL config entry."""
120113

121114
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
115+
116+
117+
async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
118+
"""Migrate old entry."""
119+
_LOGGER.debug("Migrating from version %s.%s", entry.version, entry.minor_version)
120+
121+
if entry.version > 1:
122+
# This means the user has downgraded from a future version
123+
return False
124+
125+
if entry.version == 1:
126+
old_options = {**entry.options}
127+
new_data = {}
128+
new_options: dict[str, Any] = {}
129+
130+
if (db_url := old_options.get(CONF_DB_URL)) and db_url != get_instance(
131+
hass
132+
).db_url:
133+
new_data[CONF_DB_URL] = db_url
134+
135+
new_options[CONF_COLUMN_NAME] = old_options.get(CONF_COLUMN_NAME)
136+
new_options[CONF_QUERY] = old_options.get(CONF_QUERY)
137+
new_options[CONF_ADVANCED_OPTIONS] = {}
138+
139+
for key in (
140+
CONF_VALUE_TEMPLATE,
141+
CONF_UNIT_OF_MEASUREMENT,
142+
CONF_DEVICE_CLASS,
143+
CONF_STATE_CLASS,
144+
):
145+
if (value := old_options.get(key)) is not None:
146+
new_options[CONF_ADVANCED_OPTIONS][key] = value
147+
148+
hass.config_entries.async_update_entry(
149+
entry, data=new_data, options=new_options, version=2
150+
)
151+
152+
_LOGGER.debug(
153+
"Migration to version %s.%s successful",
154+
entry.version,
155+
entry.minor_version,
156+
)
157+
158+
return True

homeassistant/components/sql/config_flow.py

Lines changed: 115 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any
77

88
import sqlalchemy
9-
from sqlalchemy.engine import Result
9+
from sqlalchemy.engine import Engine, Result
1010
from sqlalchemy.exc import MultipleResultsFound, NoSuchColumnError, SQLAlchemyError
1111
from sqlalchemy.orm import Session, scoped_session, sessionmaker
1212
import sqlparse
@@ -32,59 +32,59 @@
3232
CONF_VALUE_TEMPLATE,
3333
)
3434
from homeassistant.core import callback
35+
from homeassistant.data_entry_flow import section
3536
from homeassistant.helpers import selector
3637

37-
from .const import CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
38+
from .const import CONF_ADVANCED_OPTIONS, CONF_COLUMN_NAME, CONF_QUERY, DOMAIN
3839
from .util import resolve_db_url
3940

4041
_LOGGER = logging.getLogger(__name__)
4142

4243

4344
OPTIONS_SCHEMA: vol.Schema = vol.Schema(
4445
{
45-
vol.Optional(
46-
CONF_DB_URL,
47-
): selector.TextSelector(),
48-
vol.Required(
49-
CONF_COLUMN_NAME,
50-
): selector.TextSelector(),
51-
vol.Required(
52-
CONF_QUERY,
53-
): selector.TextSelector(selector.TextSelectorConfig(multiline=True)),
54-
vol.Optional(
55-
CONF_UNIT_OF_MEASUREMENT,
56-
): selector.TextSelector(),
57-
vol.Optional(
58-
CONF_VALUE_TEMPLATE,
59-
): selector.TemplateSelector(),
60-
vol.Optional(CONF_DEVICE_CLASS): selector.SelectSelector(
61-
selector.SelectSelectorConfig(
62-
options=[
63-
cls.value
64-
for cls in SensorDeviceClass
65-
if cls != SensorDeviceClass.ENUM
66-
],
67-
mode=selector.SelectSelectorMode.DROPDOWN,
68-
translation_key="device_class",
69-
sort=True,
70-
)
46+
vol.Required(CONF_QUERY): selector.TextSelector(
47+
selector.TextSelectorConfig(multiline=True)
7148
),
72-
vol.Optional(CONF_STATE_CLASS): selector.SelectSelector(
73-
selector.SelectSelectorConfig(
74-
options=[cls.value for cls in SensorStateClass],
75-
mode=selector.SelectSelectorMode.DROPDOWN,
76-
translation_key="state_class",
77-
sort=True,
78-
)
49+
vol.Required(CONF_COLUMN_NAME): selector.TextSelector(),
50+
vol.Required(CONF_ADVANCED_OPTIONS): section(
51+
vol.Schema(
52+
{
53+
vol.Optional(CONF_VALUE_TEMPLATE): selector.TemplateSelector(),
54+
vol.Optional(CONF_UNIT_OF_MEASUREMENT): selector.TextSelector(),
55+
vol.Optional(CONF_DEVICE_CLASS): selector.SelectSelector(
56+
selector.SelectSelectorConfig(
57+
options=[
58+
cls.value
59+
for cls in SensorDeviceClass
60+
if cls != SensorDeviceClass.ENUM
61+
],
62+
mode=selector.SelectSelectorMode.DROPDOWN,
63+
translation_key="device_class",
64+
sort=True,
65+
)
66+
),
67+
vol.Optional(CONF_STATE_CLASS): selector.SelectSelector(
68+
selector.SelectSelectorConfig(
69+
options=[cls.value for cls in SensorStateClass],
70+
mode=selector.SelectSelectorMode.DROPDOWN,
71+
translation_key="state_class",
72+
sort=True,
73+
)
74+
),
75+
}
76+
),
77+
{"collapsed": True},
7978
),
8079
}
8180
)
8281

8382
CONFIG_SCHEMA: vol.Schema = vol.Schema(
8483
{
8584
vol.Required(CONF_NAME, default="Select SQL Query"): selector.TextSelector(),
85+
vol.Optional(CONF_DB_URL): selector.TextSelector(),
8686
}
87-
).extend(OPTIONS_SCHEMA.schema)
87+
)
8888

8989

9090
def validate_sql_select(value: str) -> str:
@@ -99,6 +99,31 @@ def validate_sql_select(value: str) -> str:
9999
return str(query[0])
100100

101101

102+
def validate_db_connection(db_url: str) -> bool:
103+
"""Validate db connection."""
104+
105+
engine: Engine | None = None
106+
sess: Session | None = None
107+
try:
108+
engine = sqlalchemy.create_engine(db_url, future=True)
109+
sessmaker = scoped_session(sessionmaker(bind=engine, future=True))
110+
sess = sessmaker()
111+
sess.execute(sqlalchemy.text("select 1 as value"))
112+
except SQLAlchemyError as error:
113+
_LOGGER.debug("Execution error %s", error)
114+
if sess:
115+
sess.close()
116+
if engine:
117+
engine.dispose()
118+
raise
119+
120+
if sess:
121+
sess.close()
122+
engine.dispose()
123+
124+
return True
125+
126+
102127
def validate_query(db_url: str, query: str, column: str) -> bool:
103128
"""Validate SQL query."""
104129

@@ -136,7 +161,9 @@ def validate_query(db_url: str, query: str, column: str) -> bool:
136161
class SQLConfigFlow(ConfigFlow, domain=DOMAIN):
137162
"""Handle a config flow for SQL integration."""
138163

139-
VERSION = 1
164+
VERSION = 2
165+
166+
data: dict[str, Any]
140167

141168
@staticmethod
142169
@callback
@@ -151,17 +178,46 @@ async def async_step_user(
151178
) -> ConfigFlowResult:
152179
"""Handle the user step."""
153180
errors = {}
154-
description_placeholders = {}
155181

156182
if user_input is not None:
157183
db_url = user_input.get(CONF_DB_URL)
184+
185+
try:
186+
db_url_for_validation = resolve_db_url(self.hass, db_url)
187+
await self.hass.async_add_executor_job(
188+
validate_db_connection, db_url_for_validation
189+
)
190+
except SQLAlchemyError:
191+
errors["db_url"] = "db_url_invalid"
192+
193+
if not errors:
194+
self.data = {CONF_NAME: user_input[CONF_NAME]}
195+
if db_url and db_url_for_validation != get_instance(self.hass).db_url:
196+
self.data[CONF_DB_URL] = db_url
197+
return await self.async_step_options()
198+
199+
return self.async_show_form(
200+
step_id="user",
201+
data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input),
202+
errors=errors,
203+
)
204+
205+
async def async_step_options(
206+
self, user_input: dict[str, Any] | None = None
207+
) -> ConfigFlowResult:
208+
"""Handle the user step."""
209+
errors = {}
210+
description_placeholders = {}
211+
212+
if user_input is not None:
158213
query = user_input[CONF_QUERY]
159214
column = user_input[CONF_COLUMN_NAME]
160-
db_url_for_validation = None
161215

162216
try:
163217
query = validate_sql_select(query)
164-
db_url_for_validation = resolve_db_url(self.hass, db_url)
218+
db_url_for_validation = resolve_db_url(
219+
self.hass, self.data.get(CONF_DB_URL)
220+
)
165221
await self.hass.async_add_executor_job(
166222
validate_query, db_url_for_validation, query, column
167223
)
@@ -178,32 +234,25 @@ async def async_step_user(
178234
_LOGGER.debug("Invalid query: %s", err)
179235
errors["query"] = "query_invalid"
180236

181-
options = {
182-
CONF_QUERY: query,
183-
CONF_COLUMN_NAME: column,
184-
CONF_NAME: user_input[CONF_NAME],
237+
mod_advanced_options = {
238+
k: v
239+
for k, v in user_input[CONF_ADVANCED_OPTIONS].items()
240+
if v is not None
185241
}
186-
if uom := user_input.get(CONF_UNIT_OF_MEASUREMENT):
187-
options[CONF_UNIT_OF_MEASUREMENT] = uom
188-
if value_template := user_input.get(CONF_VALUE_TEMPLATE):
189-
options[CONF_VALUE_TEMPLATE] = value_template
190-
if device_class := user_input.get(CONF_DEVICE_CLASS):
191-
options[CONF_DEVICE_CLASS] = device_class
192-
if state_class := user_input.get(CONF_STATE_CLASS):
193-
options[CONF_STATE_CLASS] = state_class
194-
if db_url_for_validation != get_instance(self.hass).db_url:
195-
options[CONF_DB_URL] = db_url_for_validation
242+
user_input[CONF_ADVANCED_OPTIONS] = mod_advanced_options
196243

197244
if not errors:
245+
name = self.data[CONF_NAME]
246+
self.data.pop(CONF_NAME)
198247
return self.async_create_entry(
199-
title=user_input[CONF_NAME],
200-
data={},
201-
options=options,
248+
title=name,
249+
data=self.data,
250+
options=user_input,
202251
)
203252

204253
return self.async_show_form(
205-
step_id="user",
206-
data_schema=self.add_suggested_values_to_schema(CONFIG_SCHEMA, user_input),
254+
step_id="options",
255+
data_schema=self.add_suggested_values_to_schema(OPTIONS_SCHEMA, user_input),
207256
errors=errors,
208257
description_placeholders=description_placeholders,
209258
)
@@ -220,10 +269,9 @@ async def async_step_init(
220269
description_placeholders = {}
221270

222271
if user_input is not None:
223-
db_url = user_input.get(CONF_DB_URL)
272+
db_url = self.config_entry.data.get(CONF_DB_URL)
224273
query = user_input[CONF_QUERY]
225274
column = user_input[CONF_COLUMN_NAME]
226-
name = self.config_entry.options.get(CONF_NAME, self.config_entry.title)
227275

228276
try:
229277
query = validate_sql_select(query)
@@ -252,24 +300,15 @@ async def async_step_init(
252300
recorder_db,
253301
)
254302

255-
options = {
256-
CONF_QUERY: query,
257-
CONF_COLUMN_NAME: column,
258-
CONF_NAME: name,
303+
mod_advanced_options = {
304+
k: v
305+
for k, v in user_input[CONF_ADVANCED_OPTIONS].items()
306+
if v is not None
259307
}
260-
if uom := user_input.get(CONF_UNIT_OF_MEASUREMENT):
261-
options[CONF_UNIT_OF_MEASUREMENT] = uom
262-
if value_template := user_input.get(CONF_VALUE_TEMPLATE):
263-
options[CONF_VALUE_TEMPLATE] = value_template
264-
if device_class := user_input.get(CONF_DEVICE_CLASS):
265-
options[CONF_DEVICE_CLASS] = device_class
266-
if state_class := user_input.get(CONF_STATE_CLASS):
267-
options[CONF_STATE_CLASS] = state_class
268-
if db_url_for_validation != get_instance(self.hass).db_url:
269-
options[CONF_DB_URL] = db_url_for_validation
308+
user_input[CONF_ADVANCED_OPTIONS] = mod_advanced_options
270309

271310
return self.async_create_entry(
272-
data=options,
311+
data=user_input,
273312
)
274313

275314
return self.async_show_form(

homeassistant/components/sql/const.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@
99

1010
CONF_COLUMN_NAME = "column"
1111
CONF_QUERY = "query"
12+
CONF_ADVANCED_OPTIONS = "advanced_options"
1213
DB_URL_RE = re.compile("//.*:.*@")

0 commit comments

Comments
 (0)