Skip to content

Commit d65e704

Browse files
balloob and bdraco authored
Add usage_prediction integration (#151206)
Co-authored-by: J. Nick Koston <[email protected]> Co-authored-by: J. Nick Koston <[email protected]>
1 parent aadaf87 commit d65e704

File tree

13 files changed

+929
-0
lines changed

13 files changed

+929
-0
lines changed

CODEOWNERS

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

homeassistant/components/default_config/manifest.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"ssdp",
2020
"stream",
2121
"sun",
22+
"usage_prediction",
2223
"usb",
2324
"webhook",
2425
"zeroconf"
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""The usage prediction integration."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
from datetime import timedelta
7+
from typing import Any
8+
9+
from homeassistant.components import websocket_api
10+
from homeassistant.core import HomeAssistant
11+
from homeassistant.helpers import config_validation as cv
12+
from homeassistant.helpers.typing import ConfigType
13+
from homeassistant.util import dt as dt_util
14+
15+
from . import common_control
16+
from .const import DATA_CACHE, DOMAIN
17+
from .models import EntityUsageDataCache, EntityUsagePredictions
18+
19+
# Config schema for this domain, built with the empty-config helper.
CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)

# Maximum age of a cached prediction; older entries are recomputed on request.
CACHE_DURATION = timedelta(hours=24)
22+
23+
24+
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
    """Create the empty prediction cache and register the WebSocket command."""
    hass.data[DATA_CACHE] = {}
    websocket_api.async_register_command(hass, ws_common_control)
    return True
29+
30+
31+
@websocket_api.websocket_command({"type": f"{DOMAIN}/common_control"})
@websocket_api.async_response
async def ws_common_control(
    hass: HomeAssistant,
    connection: websocket_api.ActiveConnection,
    msg: dict[str, Any],
) -> None:
    """Send the requesting user's predicted entities for the current time category."""
    predictions = await get_cached_common_control(hass, connection.user.id)
    # The category is derived from the hour reported by dt_util.now(), read
    # only once the predictions are available.
    current_category = common_control.time_category(dt_util.now().hour)
    payload = {"entities": getattr(predictions, current_category)}
    connection.send_result(msg["id"], payload)
47+
48+
49+
async def get_cached_common_control(
    hass: HomeAssistant, user_id: str
) -> EntityUsagePredictions:
    """Get cached common control predictions or fetch new ones.

    Returns cached data if it's less than 24 hours old, otherwise fetches
    new data and caches it. While a fetch is running, its task is stored in
    the cache so that concurrent calls for the same user await that task
    instead of starting another fetch.

    Args:
        hass: Home Assistant instance holding the cache in ``hass.data``.
        user_id: ID of the user to get predictions for; also the cache key.

    Returns:
        The predictions for the user.

    Raises:
        Whatever the prediction task raises, including cancellation. The
        in-flight task is dropped from the cache first so a later call
        starts a fresh fetch.
    """
    cache = hass.data[DATA_CACHE]
    storage_key = user_id

    cached_data = cache.get(storage_key)

    if isinstance(cached_data, asyncio.Task):
        # If there's an ongoing task to fetch data, await its result
        return await cached_data

    # Check if cache is valid (less than 24 hours old)
    if (
        cached_data is not None
        and (dt_util.utcnow() - cached_data.timestamp) < CACHE_DURATION
    ):
        return cached_data.predictions

    # Create task fetching data
    task = hass.async_create_task(
        common_control.async_predict_common_control(hass, user_id)
    )
    cache[storage_key] = task

    try:
        predictions = await task
    except (Exception, asyncio.CancelledError):
        # CancelledError derives from BaseException, so it must be listed
        # explicitly: otherwise a cancelled task would stay in the cache and
        # every later call would await it and fail again. Only remove the
        # entry if it is still the task created here.
        if cache.get(storage_key) is task:
            del cache[storage_key]
        raise

    cache[storage_key] = EntityUsageDataCache(predictions=predictions)

    return predictions
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
"""Code to generate common control usage patterns."""
2+
3+
from __future__ import annotations
4+
5+
from collections import Counter
6+
from collections.abc import Callable
7+
from datetime import datetime, timedelta
8+
from functools import cache
9+
import logging
10+
from typing import Any, Literal, cast
11+
12+
from sqlalchemy import select
13+
from sqlalchemy.orm import Session
14+
15+
from homeassistant.components.recorder import get_instance
16+
from homeassistant.components.recorder.db_schema import EventData, Events, EventTypes
17+
from homeassistant.components.recorder.models import uuid_hex_to_bytes_or_none
18+
from homeassistant.components.recorder.util import session_scope
19+
from homeassistant.const import Platform
20+
from homeassistant.core import HomeAssistant
21+
from homeassistant.util import dt as dt_util
22+
from homeassistant.util.json import json_loads_object
23+
24+
from .models import EntityUsagePredictions
25+
26+
_LOGGER = logging.getLogger(__name__)

# Time categories for usage patterns; one usage counter is kept per category
TIME_CATEGORIES = ["morning", "afternoon", "evening", "night"]

# Number of most common entities kept per time category
RESULTS_TO_INCLUDE = 8

# List of domains for which we want to track usage.
# Entity IDs whose domain prefix (the part before the first ".") is not in
# this set are dropped when counting.
ALLOWED_DOMAINS = {
    # Entity platforms
    Platform.AIR_QUALITY,
    Platform.ALARM_CONTROL_PANEL,
    Platform.BINARY_SENSOR,
    Platform.BUTTON,
    Platform.CALENDAR,
    Platform.CAMERA,
    Platform.CLIMATE,
    Platform.COVER,
    Platform.DATE,
    Platform.DATETIME,
    Platform.FAN,
    Platform.HUMIDIFIER,
    Platform.IMAGE,
    Platform.LAWN_MOWER,
    Platform.LIGHT,
    Platform.LOCK,
    Platform.MEDIA_PLAYER,
    Platform.NUMBER,
    Platform.SCENE,
    Platform.SELECT,
    Platform.SENSOR,
    Platform.SIREN,
    Platform.SWITCH,
    Platform.TEXT,
    Platform.TIME,
    Platform.TODO,
    Platform.UPDATE,
    Platform.VACUUM,
    Platform.VALVE,
    Platform.WAKE_WORD,
    Platform.WATER_HEATER,
    Platform.WEATHER,
    # Helpers with own domain
    "counter",
    "group",
    "input_boolean",
    "input_button",
    "input_datetime",
    "input_number",
    "input_select",
    "input_text",
    "schedule",
    "timer",
}
80+
81+
82+
@cache
def time_category(hour: int) -> Literal["morning", "afternoon", "evening", "night"]:
    """Map an hour of the day to its time category.

    Hours 6-11 are "morning", 12-17 are "afternoon" and 18-21 are "evening";
    every other value is "night".
    """
    # Anything outside the 6..21 daytime window is night.
    if not 6 <= hour < 22:
        return "night"
    if hour < 12:
        return "morning"
    return "afternoon" if hour < 18 else "evening"
92+
93+
94+
async def async_predict_common_control(
    hass: HomeAssistant, user_id: str
) -> EntityUsagePredictions:
    """Compute the entities a user most commonly controls, per time category.

    The database work is handed to the recorder instance via
    ``async_add_executor_job`` and awaited.

    Args:
        hass: Home Assistant instance
        user_id: User ID to filter events by.

    Returns:
        The EntityUsagePredictions produced by ``_fetch_and_process_data``.
    """
    return await get_instance(hass).async_add_executor_job(
        _fetch_with_session, hass, _fetch_and_process_data, user_id
    )
113+
114+
115+
def _extract_entity_ids(service_data: dict[str, Any]) -> list[str]:
    """Return the entity IDs referenced by service call data that are in allowed domains."""
    raw_ids: Any = None
    target = service_data.get("target")
    # A target that is not a mapping is ignored rather than crashing on .get()
    if isinstance(target, dict):
        raw_ids = target.get("entity_id")
    if not raw_ids:
        # No usable target entity IDs, fall back to the top-level key
        raw_ids = service_data.get("entity_id")

    if raw_ids is None:
        return []

    if not isinstance(raw_ids, list):
        raw_ids = [raw_ids]

    # Non-string entries cannot be entity IDs; skip them instead of failing
    return [
        entity_id
        for entity_id in raw_ids
        if isinstance(entity_id, str)
        and entity_id.split(".")[0] in ALLOWED_DOMAINS
    ]


def _top_entities(counter: Counter[str]) -> list[str]:
    """Return the RESULTS_TO_INCLUDE most common entity IDs of a counter."""
    return [ent_id for (ent_id, _) in counter.most_common(RESULTS_TO_INCLUDE)]


def _fetch_and_process_data(session: Session, user_id: str) -> EntityUsagePredictions:
    """Fetch and process service call events from the database.

    Selects the ``call_service`` events fired in the given user's context
    during the last 30 days, ordered by fire time, and counts per time
    category how often each entity in an allowed domain was referenced.
    Only the first service call of each context is considered.

    Args:
        session: Database session used to run the query.
        user_id: User ID to filter events by; converted to bytes with
            ``uuid_hex_to_bytes_or_none``.

    Returns:
        The most common entity IDs for each time category.

    Raises:
        ValueError: If the user ID cannot be converted to bytes.
    """
    # Prepare a dictionary to track results
    results: dict[str, Counter[str]] = {
        time_cat: Counter() for time_cat in TIME_CATEGORIES
    }

    # Keep track of contexts that we processed so that we will only process
    # the first service call in a context, and not subsequent calls.
    context_processed: set[bytes] = set()
    thirty_days_ago_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp()
    user_id_bytes = uuid_hex_to_bytes_or_none(user_id)
    if not user_id_bytes:
        raise ValueError("Invalid user_id format")

    # Build the main query for events with their data
    query = (
        select(
            Events.context_id_bin,
            Events.time_fired_ts,
            EventData.shared_data,
        )
        .select_from(Events)
        .outerjoin(EventData, Events.data_id == EventData.data_id)
        .outerjoin(EventTypes, Events.event_type_id == EventTypes.event_type_id)
        .where(Events.time_fired_ts >= thirty_days_ago_ts)
        .where(Events.context_user_id_bin == user_id_bytes)
        .where(EventTypes.event_type == "call_service")
        .order_by(Events.time_fired_ts)
    )

    # Execute the query
    context_id: bytes
    time_fired_ts: float
    shared_data: str | None
    local_time_zone = dt_util.get_default_time_zone()
    for context_id, time_fired_ts, shared_data in (
        session.connection().execute(query).all()
    ):
        # Skip if we have already processed an event that was part of this context
        if context_id in context_processed:
            continue

        # Mark this context as processed, even if the event is skipped below
        context_processed.add(context_id)

        # Parse the event data
        if not shared_data:
            continue

        try:
            event_data = json_loads_object(shared_data)
        except (ValueError, TypeError) as err:
            _LOGGER.debug("Failed to parse event data: %s", err)
            continue

        # Empty event data, skipping
        if not event_data:
            continue

        service_data = cast(dict[str, Any] | None, event_data.get("service_data"))

        # No service data found, skipping. cast() is only a typing hint, so
        # the shape of the stored JSON is verified at runtime as well.
        if not service_data or not isinstance(service_data, dict):
            continue

        entity_ids = _extract_entity_ids(service_data)
        if not entity_ids:
            continue

        # Without a timestamp the event cannot be assigned to a time category
        if not time_fired_ts:
            continue

        # Convert to local time for time category determination
        period = time_category(
            datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
        )

        # Count entity usage
        for entity_id in entity_ids:
            results[period][entity_id] += 1

    return EntityUsagePredictions(
        morning=_top_entities(results["morning"]),
        afternoon=_top_entities(results["afternoon"]),
        evening=_top_entities(results["evening"]),
        night=_top_entities(results["night"]),
    )
232+
233+
234+
def _fetch_with_session(
    hass: HomeAssistant,
    fetch_func: Callable[..., EntityUsagePredictions],
    *args: object,
) -> EntityUsagePredictions:
    """Execute a fetch function with a read-only database session.

    Args:
        hass: Home Assistant instance passed to ``session_scope``.
        fetch_func: Callable invoked as ``fetch_func(session, *args)``.
        *args: Extra positional arguments forwarded after the session.

    Returns:
        Whatever ``fetch_func`` returns.
    """
    with session_scope(hass=hass, read_only=True) as session:
        return fetch_func(session, *args)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Constants for the usage prediction integration."""
2+
3+
import asyncio
4+
5+
from homeassistant.util.hass_dict import HassKey
6+
7+
from .models import EntityUsageDataCache, EntityUsagePredictions
8+
9+
# Domain of the integration
DOMAIN = "usage_prediction"

# Typed key for the prediction cache: a dict mapping a string key to either
# an in-flight prediction task or a finished, cached result.
DATA_CACHE: HassKey[
    dict[str, asyncio.Task[EntityUsagePredictions] | EntityUsageDataCache]
] = HassKey("usage_prediction")
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"domain": "usage_prediction",
3+
"name": "Usage Prediction",
4+
"codeowners": ["@home-assistant/core"],
5+
"dependencies": ["http", "recorder"],
6+
"documentation": "https://www.home-assistant.io/integrations/usage_prediction",
7+
"integration_type": "system",
8+
"iot_class": "calculated",
9+
"quality_scale": "internal"
10+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Models for the usage prediction integration."""
2+
3+
from dataclasses import dataclass, field
4+
from datetime import datetime
5+
6+
from homeassistant.util import dt as dt_util
7+
8+
9+
@dataclass
class EntityUsagePredictions:
    """Prediction which entities are likely to be used in each time category.

    Each field is a list of strings that defaults to a new empty list.
    """

    morning: list[str] = field(default_factory=list)
    afternoon: list[str] = field(default_factory=list)
    evening: list[str] = field(default_factory=list)
    night: list[str] = field(default_factory=list)
17+
18+
19+
@dataclass
class EntityUsageDataCache:
    """Entity usage predictions together with the time they were stored."""

    predictions: EntityUsagePredictions
    # Defaults to the value of dt_util.utcnow() at instance creation
    timestamp: datetime = field(default_factory=dt_util.utcnow)

0 commit comments

Comments
 (0)