66from typing import Any
77
88import sqlalchemy
9- from sqlalchemy .engine import Result
9+ from sqlalchemy .engine import Engine , Result
1010from sqlalchemy .exc import MultipleResultsFound , NoSuchColumnError , SQLAlchemyError
1111from sqlalchemy .orm import Session , scoped_session , sessionmaker
1212import sqlparse
3232 CONF_VALUE_TEMPLATE ,
3333)
3434from homeassistant .core import callback
35+ from homeassistant .data_entry_flow import section
3536from homeassistant .helpers import selector
3637
37- from .const import CONF_COLUMN_NAME , CONF_QUERY , DOMAIN
38+ from .const import CONF_ADVANCED_OPTIONS , CONF_COLUMN_NAME , CONF_QUERY , DOMAIN
3839from .util import resolve_db_url
3940
4041_LOGGER = logging .getLogger (__name__ )
4142
4243
4344OPTIONS_SCHEMA : vol .Schema = vol .Schema (
4445 {
45- vol .Optional (
46- CONF_DB_URL ,
47- ): selector .TextSelector (),
48- vol .Required (
49- CONF_COLUMN_NAME ,
50- ): selector .TextSelector (),
51- vol .Required (
52- CONF_QUERY ,
53- ): selector .TextSelector (selector .TextSelectorConfig (multiline = True )),
54- vol .Optional (
55- CONF_UNIT_OF_MEASUREMENT ,
56- ): selector .TextSelector (),
57- vol .Optional (
58- CONF_VALUE_TEMPLATE ,
59- ): selector .TemplateSelector (),
60- vol .Optional (CONF_DEVICE_CLASS ): selector .SelectSelector (
61- selector .SelectSelectorConfig (
62- options = [
63- cls .value
64- for cls in SensorDeviceClass
65- if cls != SensorDeviceClass .ENUM
66- ],
67- mode = selector .SelectSelectorMode .DROPDOWN ,
68- translation_key = "device_class" ,
69- sort = True ,
70- )
46+ vol .Required (CONF_QUERY ): selector .TextSelector (
47+ selector .TextSelectorConfig (multiline = True )
7148 ),
72- vol .Optional (CONF_STATE_CLASS ): selector .SelectSelector (
73- selector .SelectSelectorConfig (
74- options = [cls .value for cls in SensorStateClass ],
75- mode = selector .SelectSelectorMode .DROPDOWN ,
76- translation_key = "state_class" ,
77- sort = True ,
78- )
49+ vol .Required (CONF_COLUMN_NAME ): selector .TextSelector (),
50+ vol .Required (CONF_ADVANCED_OPTIONS ): section (
51+ vol .Schema (
52+ {
53+ vol .Optional (CONF_VALUE_TEMPLATE ): selector .TemplateSelector (),
54+ vol .Optional (CONF_UNIT_OF_MEASUREMENT ): selector .TextSelector (),
55+ vol .Optional (CONF_DEVICE_CLASS ): selector .SelectSelector (
56+ selector .SelectSelectorConfig (
57+ options = [
58+ cls .value
59+ for cls in SensorDeviceClass
60+ if cls != SensorDeviceClass .ENUM
61+ ],
62+ mode = selector .SelectSelectorMode .DROPDOWN ,
63+ translation_key = "device_class" ,
64+ sort = True ,
65+ )
66+ ),
67+ vol .Optional (CONF_STATE_CLASS ): selector .SelectSelector (
68+ selector .SelectSelectorConfig (
69+ options = [cls .value for cls in SensorStateClass ],
70+ mode = selector .SelectSelectorMode .DROPDOWN ,
71+ translation_key = "state_class" ,
72+ sort = True ,
73+ )
74+ ),
75+ }
76+ ),
77+ {"collapsed" : True },
7978 ),
8079 }
8180)
8281
8382CONFIG_SCHEMA : vol .Schema = vol .Schema (
8483 {
8584 vol .Required (CONF_NAME , default = "Select SQL Query" ): selector .TextSelector (),
85+ vol .Optional (CONF_DB_URL ): selector .TextSelector (),
8686 }
87- ). extend ( OPTIONS_SCHEMA . schema )
87+ )
8888
8989
9090def validate_sql_select (value : str ) -> str :
@@ -99,6 +99,31 @@ def validate_sql_select(value: str) -> str:
9999 return str (query [0 ])
100100
101101
102+ def validate_db_connection (db_url : str ) -> bool :
103+ """Validate db connection."""
104+
105+ engine : Engine | None = None
106+ sess : Session | None = None
107+ try :
108+ engine = sqlalchemy .create_engine (db_url , future = True )
109+ sessmaker = scoped_session (sessionmaker (bind = engine , future = True ))
110+ sess = sessmaker ()
111+ sess .execute (sqlalchemy .text ("select 1 as value" ))
112+ except SQLAlchemyError as error :
113+ _LOGGER .debug ("Execution error %s" , error )
114+ if sess :
115+ sess .close ()
116+ if engine :
117+ engine .dispose ()
118+ raise
119+
120+ if sess :
121+ sess .close ()
122+ engine .dispose ()
123+
124+ return True
125+
126+
102127def validate_query (db_url : str , query : str , column : str ) -> bool :
103128 """Validate SQL query."""
104129
@@ -136,7 +161,9 @@ def validate_query(db_url: str, query: str, column: str) -> bool:
136161class SQLConfigFlow (ConfigFlow , domain = DOMAIN ):
137162 """Handle a config flow for SQL integration."""
138163
139- VERSION = 1
164+ VERSION = 2
165+
166+ data : dict [str , Any ]
140167
141168 @staticmethod
142169 @callback
@@ -151,17 +178,46 @@ async def async_step_user(
151178 ) -> ConfigFlowResult :
152179 """Handle the user step."""
153180 errors = {}
154- description_placeholders = {}
155181
156182 if user_input is not None :
157183 db_url = user_input .get (CONF_DB_URL )
184+
185+ try :
186+ db_url_for_validation = resolve_db_url (self .hass , db_url )
187+ await self .hass .async_add_executor_job (
188+ validate_db_connection , db_url_for_validation
189+ )
190+ except SQLAlchemyError :
191+ errors ["db_url" ] = "db_url_invalid"
192+
193+ if not errors :
194+ self .data = {CONF_NAME : user_input [CONF_NAME ]}
195+ if db_url and db_url_for_validation != get_instance (self .hass ).db_url :
196+ self .data [CONF_DB_URL ] = db_url
197+ return await self .async_step_options ()
198+
199+ return self .async_show_form (
200+ step_id = "user" ,
201+ data_schema = self .add_suggested_values_to_schema (CONFIG_SCHEMA , user_input ),
202+ errors = errors ,
203+ )
204+
205+ async def async_step_options (
206+ self , user_input : dict [str , Any ] | None = None
207+ ) -> ConfigFlowResult :
208+ """Handle the user step."""
209+ errors = {}
210+ description_placeholders = {}
211+
212+ if user_input is not None :
158213 query = user_input [CONF_QUERY ]
159214 column = user_input [CONF_COLUMN_NAME ]
160- db_url_for_validation = None
161215
162216 try :
163217 query = validate_sql_select (query )
164- db_url_for_validation = resolve_db_url (self .hass , db_url )
218+ db_url_for_validation = resolve_db_url (
219+ self .hass , self .data .get (CONF_DB_URL )
220+ )
165221 await self .hass .async_add_executor_job (
166222 validate_query , db_url_for_validation , query , column
167223 )
@@ -178,32 +234,25 @@ async def async_step_user(
178234 _LOGGER .debug ("Invalid query: %s" , err )
179235 errors ["query" ] = "query_invalid"
180236
181- options = {
182- CONF_QUERY : query ,
183- CONF_COLUMN_NAME : column ,
184- CONF_NAME : user_input [ CONF_NAME ],
237+ mod_advanced_options = {
238+ k : v
239+ for k , v in user_input [ CONF_ADVANCED_OPTIONS ]. items ()
240+ if v is not None
185241 }
186- if uom := user_input .get (CONF_UNIT_OF_MEASUREMENT ):
187- options [CONF_UNIT_OF_MEASUREMENT ] = uom
188- if value_template := user_input .get (CONF_VALUE_TEMPLATE ):
189- options [CONF_VALUE_TEMPLATE ] = value_template
190- if device_class := user_input .get (CONF_DEVICE_CLASS ):
191- options [CONF_DEVICE_CLASS ] = device_class
192- if state_class := user_input .get (CONF_STATE_CLASS ):
193- options [CONF_STATE_CLASS ] = state_class
194- if db_url_for_validation != get_instance (self .hass ).db_url :
195- options [CONF_DB_URL ] = db_url_for_validation
242+ user_input [CONF_ADVANCED_OPTIONS ] = mod_advanced_options
196243
197244 if not errors :
245+ name = self .data [CONF_NAME ]
246+ self .data .pop (CONF_NAME )
198247 return self .async_create_entry (
199- title = user_input [ CONF_NAME ] ,
200- data = {} ,
201- options = options ,
248+ title = name ,
249+ data = self . data ,
250+ options = user_input ,
202251 )
203252
204253 return self .async_show_form (
205- step_id = "user " ,
206- data_schema = self .add_suggested_values_to_schema (CONFIG_SCHEMA , user_input ),
254+ step_id = "options " ,
255+ data_schema = self .add_suggested_values_to_schema (OPTIONS_SCHEMA , user_input ),
207256 errors = errors ,
208257 description_placeholders = description_placeholders ,
209258 )
@@ -220,10 +269,9 @@ async def async_step_init(
220269 description_placeholders = {}
221270
222271 if user_input is not None :
223- db_url = user_input .get (CONF_DB_URL )
272+ db_url = self . config_entry . data .get (CONF_DB_URL )
224273 query = user_input [CONF_QUERY ]
225274 column = user_input [CONF_COLUMN_NAME ]
226- name = self .config_entry .options .get (CONF_NAME , self .config_entry .title )
227275
228276 try :
229277 query = validate_sql_select (query )
@@ -252,24 +300,15 @@ async def async_step_init(
252300 recorder_db ,
253301 )
254302
255- options = {
256- CONF_QUERY : query ,
257- CONF_COLUMN_NAME : column ,
258- CONF_NAME : name ,
303+ mod_advanced_options = {
304+ k : v
305+ for k , v in user_input [ CONF_ADVANCED_OPTIONS ]. items ()
306+ if v is not None
259307 }
260- if uom := user_input .get (CONF_UNIT_OF_MEASUREMENT ):
261- options [CONF_UNIT_OF_MEASUREMENT ] = uom
262- if value_template := user_input .get (CONF_VALUE_TEMPLATE ):
263- options [CONF_VALUE_TEMPLATE ] = value_template
264- if device_class := user_input .get (CONF_DEVICE_CLASS ):
265- options [CONF_DEVICE_CLASS ] = device_class
266- if state_class := user_input .get (CONF_STATE_CLASS ):
267- options [CONF_STATE_CLASS ] = state_class
268- if db_url_for_validation != get_instance (self .hass ).db_url :
269- options [CONF_DB_URL ] = db_url_for_validation
308+ user_input [CONF_ADVANCED_OPTIONS ] = mod_advanced_options
270309
271310 return self .async_create_entry (
272- data = options ,
311+ data = user_input ,
273312 )
274313
275314 return self .async_show_form (
0 commit comments