|
| 1 | +"""Code to generate common control usage patterns.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from collections import Counter |
| 6 | +from collections.abc import Callable |
| 7 | +from datetime import datetime, timedelta |
| 8 | +from functools import cache |
| 9 | +import logging |
| 10 | +from typing import Any, Literal, cast |
| 11 | + |
| 12 | +from sqlalchemy import select |
| 13 | +from sqlalchemy.orm import Session |
| 14 | + |
| 15 | +from homeassistant.components.recorder import get_instance |
| 16 | +from homeassistant.components.recorder.db_schema import EventData, Events, EventTypes |
| 17 | +from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none |
| 18 | +from homeassistant.components.recorder.util import session_scope |
| 19 | +from homeassistant.const import Platform |
| 20 | +from homeassistant.core import HomeAssistant |
| 21 | +from homeassistant.util import dt as dt_util |
| 22 | +from homeassistant.util.json import json_loads_object |
| 23 | + |
| 24 | +from .models import EntityUsagePredictions |
| 25 | + |
# Module-level logger
_LOGGER = logging.getLogger(__name__)

# Names of the time-of-day buckets that usage is grouped into
TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"]

# Maximum number of entity IDs reported per time category
RESULTS_TO_INCLUDE = 8

# Domains for which usage is tracked; entity IDs from any other domain
# are ignored when counting.
ALLOWED_DOMAINS = {
    # Entity platforms
    Platform.AIR_QUALITY,
    Platform.ALARM_CONTROL_PANEL,
    Platform.BINARY_SENSOR,
    Platform.BUTTON,
    Platform.CALENDAR,
    Platform.CAMERA,
    Platform.CLIMATE,
    Platform.COVER,
    Platform.DATE,
    Platform.DATETIME,
    Platform.FAN,
    Platform.HUMIDIFIER,
    Platform.IMAGE,
    Platform.LAWN_MOWER,
    Platform.LIGHT,
    Platform.LOCK,
    Platform.MEDIA_PLAYER,
    Platform.NUMBER,
    Platform.SCENE,
    Platform.SELECT,
    Platform.SENSOR,
    Platform.SIREN,
    Platform.SWITCH,
    Platform.TEXT,
    Platform.TIME,
    Platform.TODO,
    Platform.UPDATE,
    Platform.VACUUM,
    Platform.VALVE,
    Platform.WAKE_WORD,
    Platform.WATER_HEATER,
    Platform.WEATHER,
} | {
    # Helpers with own domain
    "counter",
    "group",
    "input_boolean",
    "input_button",
    "input_datetime",
    "input_number",
    "input_select",
    "input_text",
    "schedule",
    "timer",
}
| 80 | + |
| 81 | + |
@cache
def time_category(hour: int) -> Literal["morning", "afternoon", "evening", "night"]:
    """Map an hour of the day to its time-of-day bucket.

    Args:
        hour: Hour of the day.

    Returns:
        "morning" for 6-11, "afternoon" for 12-17, "evening" for 18-21 and
        "night" for every other value.
    """
    # Anything outside the 06:00-21:59 window is treated as night.
    if not 6 <= hour < 22:
        return "night"
    if hour < 12:
        return "morning"
    if hour < 18:
        return "afternoon"
    return "evening"
| 92 | + |
| 93 | + |
async def async_predict_common_control(
    hass: HomeAssistant, user_id: str
) -> EntityUsagePredictions:
    """Predict the entities a user commonly controls, per time of day.

    The database work is not done on the event loop; it is handed to the
    recorder's executor, which runs _fetch_and_process_data inside a
    session opened by _fetch_with_session.

    Args:
        hass: Home Assistant instance.
        user_id: ID of the user whose events are analysed.

    Returns:
        The EntityUsagePredictions produced by _fetch_and_process_data,
        holding the most commonly used entity IDs for each time category.
    """
    # Arguments forwarded to _fetch_with_session on the executor thread
    job_args = (hass, _fetch_and_process_data, user_id)
    return await get_instance(hass).async_add_executor_job(
        _fetch_with_session, *job_args
    )
| 113 | + |
| 114 | + |
def _extract_entity_ids(shared_data: str | None) -> list[str]:
    """Return the allowed-domain entity IDs referenced by a service call event.

    Args:
        shared_data: Raw JSON payload of the event, or None/empty when the
            event has no stored data.

    Returns:
        The entity IDs found under ``service_data["target"]["entity_id"]``
        (preferred) or ``service_data["entity_id"]``, restricted to
        ALLOWED_DOMAINS. The list is empty when the payload is missing,
        cannot be parsed, is malformed, or references no allowed entity.
    """
    if not shared_data:
        return []

    try:
        event_data = json_loads_object(shared_data)
    except (ValueError, TypeError) as err:
        _LOGGER.debug("Failed to parse event data: %s", err)
        return []

    service_data = event_data.get("service_data")
    # A malformed payload must only skip this event, not abort the whole run.
    if not service_data or not isinstance(service_data, dict):
        return []

    # Prefer the target's entity IDs; fall back to the top-level entity_id.
    entity_ids: Any = None
    target = service_data.get("target")
    if isinstance(target, dict):
        entity_ids = target.get("entity_id")
    if not entity_ids:
        entity_ids = service_data.get("entity_id")

    if entity_ids is None:
        return []

    if not isinstance(entity_ids, list):
        entity_ids = [entity_ids]

    # Keep only well-formed entity IDs whose domain is tracked.
    return [
        entity_id
        for entity_id in entity_ids
        if isinstance(entity_id, str)
        and entity_id.partition(".")[0] in ALLOWED_DOMAINS
    ]


def _fetch_and_process_data(session: Session, user_id: str) -> EntityUsagePredictions:
    """Fetch and process service call events from the database.

    Reads the ``call_service`` events fired in the user's context during the
    last 30 days, counts how often each allowed entity was targeted in each
    time category (using the default local time zone), and keeps the
    RESULTS_TO_INCLUDE most common entity IDs per category.

    Args:
        session: Database session used to run the query.
        user_id: ID of the user whose events are analysed.

    Returns:
        EntityUsagePredictions with the most common entity IDs for the
        morning, afternoon, evening and night categories.

    Raises:
        ValueError: If user_id cannot be converted to its binary form.
    """
    user_id_bytes = uuid_hex_to_bytes_or_none(user_id)
    if not user_id_bytes:
        raise ValueError("Invalid user_id format")

    # Prepare a dictionary to track results
    results: dict[str, Counter[str]] = {
        time_cat: Counter() for time_cat in TIME_CATEGORIES
    }

    # Keep track of contexts that we processed so that we will only process
    # the first service call in a context, and not subsequent calls.
    context_processed: set[bytes] = set()
    thirty_days_ago_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp()

    # Build the main query for events with their data
    query = (
        select(
            Events.context_id_bin,
            Events.time_fired_ts,
            EventData.shared_data,
        )
        .select_from(Events)
        .outerjoin(EventData, Events.data_id == EventData.data_id)
        .outerjoin(EventTypes, Events.event_type_id == EventTypes.event_type_id)
        .where(Events.time_fired_ts >= thirty_days_ago_ts)
        .where(Events.context_user_id_bin == user_id_bytes)
        .where(EventTypes.event_type == "call_service")
        .order_by(Events.time_fired_ts)
    )

    # Execute the query
    context_id: bytes
    time_fired_ts: float
    shared_data: str | None
    local_time_zone = dt_util.get_default_time_zone()
    for context_id, time_fired_ts, shared_data in (
        session.connection().execute(query).all()
    ):
        # Skip if we have already processed an event that was part of this context
        if context_id in context_processed:
            continue

        # Mark this context as processed, even if the event is skipped below
        context_processed.add(context_id)

        # Without a timestamp the event cannot be assigned to a time category
        if not time_fired_ts:
            continue

        entity_ids = _extract_entity_ids(shared_data)
        if not entity_ids:
            continue

        # Convert to local time for time category determination
        period = time_category(
            datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
        )

        # Count entity usage (one increment per listed entity ID)
        results[period].update(entity_ids)

    # Reduce each counter to its most frequently used entity IDs
    top_entities = {
        time_cat: [ent_id for ent_id, _ in counter.most_common(RESULTS_TO_INCLUDE)]
        for time_cat, counter in results.items()
    }

    return EntityUsagePredictions(
        morning=top_entities["morning"],
        afternoon=top_entities["afternoon"],
        evening=top_entities["evening"],
        night=top_entities["night"],
    )
| 232 | + |
| 233 | + |
def _fetch_with_session(
    hass: HomeAssistant,
    fetch_func: Callable[[Session], EntityUsagePredictions],
    *args: object,
) -> EntityUsagePredictions:
    """Run fetch_func inside a read-only database session.

    Args:
        hass: Home Assistant instance used to open the session scope.
        fetch_func: Callable invoked with the session followed by ``args``.
        *args: Extra positional arguments forwarded to fetch_func.

    Returns:
        Whatever fetch_func returns.
    """
    scope = session_scope(hass=hass, read_only=True)
    with scope as db_session:
        predictions = fetch_func(db_session, *args)
    return predictions
0 commit comments