|
| 1 | +"""DataUpdateCoordinator for Nederlandse Spoorwegen.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from dataclasses import dataclass |
| 6 | +from datetime import datetime |
| 7 | +import logging |
| 8 | + |
| 9 | +from ns_api import NSAPI, Trip |
| 10 | +from requests.exceptions import ConnectionError, HTTPError, Timeout |
| 11 | + |
| 12 | +from homeassistant.config_entries import ConfigEntry, ConfigSubentry |
| 13 | +from homeassistant.const import CONF_API_KEY, CONF_NAME |
| 14 | +from homeassistant.core import HomeAssistant |
| 15 | +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed |
| 16 | +from homeassistant.util import dt as dt_util |
| 17 | + |
| 18 | +from .const import ( |
| 19 | + AMS_TZ, |
| 20 | + CONF_FROM, |
| 21 | + CONF_TIME, |
| 22 | + CONF_TO, |
| 23 | + CONF_VIA, |
| 24 | + DOMAIN, |
| 25 | + SCAN_INTERVAL, |
| 26 | +) |
| 27 | + |
| 28 | +_LOGGER = logging.getLogger(__name__) |
| 29 | + |
| 30 | + |
| 31 | +def _now_nl() -> datetime: |
| 32 | + """Return current time in Europe/Amsterdam timezone.""" |
| 33 | + return dt_util.now(AMS_TZ) |
| 34 | + |
| 35 | + |
| 36 | +type NSConfigEntry = ConfigEntry[dict[str, NSDataUpdateCoordinator]] |
| 37 | + |
| 38 | + |
| 39 | +@dataclass |
| 40 | +class NSRouteResult: |
| 41 | + """Data class for Nederlandse Spoorwegen API results.""" |
| 42 | + |
| 43 | + trips: list[Trip] |
| 44 | + first_trip: Trip | None = None |
| 45 | + next_trip: Trip | None = None |
| 46 | + |
| 47 | + |
| 48 | +class NSDataUpdateCoordinator(DataUpdateCoordinator[NSRouteResult]): |
| 49 | + """Class to manage fetching Nederlandse Spoorwegen data from the API for a single route.""" |
| 50 | + |
| 51 | + def __init__( |
| 52 | + self, |
| 53 | + hass: HomeAssistant, |
| 54 | + config_entry: NSConfigEntry, |
| 55 | + route_id: str, |
| 56 | + subentry: ConfigSubentry, |
| 57 | + ) -> None: |
| 58 | + """Initialize the coordinator for a specific route.""" |
| 59 | + super().__init__( |
| 60 | + hass, |
| 61 | + _LOGGER, |
| 62 | + name=f"{DOMAIN}_{route_id}", |
| 63 | + update_interval=SCAN_INTERVAL, |
| 64 | + config_entry=config_entry, |
| 65 | + ) |
| 66 | + self.id = route_id |
| 67 | + self.nsapi = NSAPI(config_entry.data[CONF_API_KEY]) |
| 68 | + self.name = subentry.data[CONF_NAME] |
| 69 | + self.departure = subentry.data[CONF_FROM] |
| 70 | + self.destination = subentry.data[CONF_TO] |
| 71 | + self.via = subentry.data.get(CONF_VIA) |
| 72 | + self.departure_time = subentry.data.get(CONF_TIME) # str | None |
| 73 | + |
| 74 | + async def _async_update_data(self) -> NSRouteResult: |
| 75 | + """Fetch data from NS API for this specific route.""" |
| 76 | + trips: list[Trip] = [] |
| 77 | + first_trip: Trip | None = None |
| 78 | + next_trip: Trip | None = None |
| 79 | + try: |
| 80 | + trips = await self._get_trips( |
| 81 | + self.departure, |
| 82 | + self.destination, |
| 83 | + self.via, |
| 84 | + departure_time=self.departure_time, |
| 85 | + ) |
| 86 | + |
| 87 | + except (ConnectionError, Timeout, HTTPError, ValueError) as err: |
| 88 | + # Surface API failures to Home Assistant so the entities become unavailable |
| 89 | + raise UpdateFailed(f"API communication error: {err}") from err |
| 90 | + |
| 91 | + # Filter out trips that have already departed (trips are already sorted) |
| 92 | + future_trips = self._remove_trips_in_the_past(trips) |
| 93 | + |
| 94 | + # Process trips to find current and next departure |
| 95 | + first_trip, next_trip = self._get_first_and_next_trips(future_trips) |
| 96 | + |
| 97 | + return NSRouteResult( |
| 98 | + trips=trips, |
| 99 | + first_trip=first_trip, |
| 100 | + next_trip=next_trip, |
| 101 | + ) |
| 102 | + |
| 103 | + def _get_time_from_route(self, time_str: str | None) -> str: |
| 104 | + """Combine today's date with a time string if needed.""" |
| 105 | + if not time_str: |
| 106 | + return _now_nl().strftime("%d-%m-%Y %H:%M") |
| 107 | + |
| 108 | + if ( |
| 109 | + isinstance(time_str, str) |
| 110 | + and len(time_str.split(":")) in (2, 3) |
| 111 | + and " " not in time_str |
| 112 | + ): |
| 113 | + today = _now_nl().strftime("%d-%m-%Y") |
| 114 | + return f"{today} {time_str[:5]}" |
| 115 | + # Fallback: use current date and time |
| 116 | + return _now_nl().strftime("%d-%m-%Y %H:%M") |
| 117 | + |
| 118 | + async def _get_trips( |
| 119 | + self, |
| 120 | + departure: str, |
| 121 | + destination: str, |
| 122 | + via: str | None = None, |
| 123 | + departure_time: str | None = None, |
| 124 | + ) -> list[Trip]: |
| 125 | + """Get trips from NS API, sorted by departure time.""" |
| 126 | + |
| 127 | + # Convert time to full date-time string if needed and default to Dutch local time if not provided |
| 128 | + time_str = self._get_time_from_route(departure_time) |
| 129 | + |
| 130 | + trips = await self.hass.async_add_executor_job( |
| 131 | + self.nsapi.get_trips, |
| 132 | + time_str, # trip_time |
| 133 | + departure, # departure |
| 134 | + via, # via |
| 135 | + destination, # destination |
| 136 | + True, # exclude_high_speed |
| 137 | + 0, # year_card |
| 138 | + 2, # max_number_of_transfers |
| 139 | + ) |
| 140 | + |
| 141 | + if not trips: |
| 142 | + return [] |
| 143 | + |
| 144 | + return sorted( |
| 145 | + trips, |
| 146 | + key=lambda trip: ( |
| 147 | + trip.departure_time_actual |
| 148 | + if trip.departure_time_actual is not None |
| 149 | + else trip.departure_time_planned |
| 150 | + if trip.departure_time_planned is not None |
| 151 | + else _now_nl() |
| 152 | + ), |
| 153 | + ) |
| 154 | + |
| 155 | + def _get_first_and_next_trips( |
| 156 | + self, trips: list[Trip] |
| 157 | + ) -> tuple[Trip | None, Trip | None]: |
| 158 | + """Process trips to find the first and next departure.""" |
| 159 | + if not trips: |
| 160 | + return None, None |
| 161 | + |
| 162 | + # First trip is the earliest future trip |
| 163 | + first_trip = trips[0] |
| 164 | + |
| 165 | + # Find next trip with different departure time |
| 166 | + next_trip = self._find_next_trip(trips, first_trip) |
| 167 | + |
| 168 | + return first_trip, next_trip |
| 169 | + |
| 170 | + def _remove_trips_in_the_past(self, trips: list[Trip]) -> list[Trip]: |
| 171 | + """Filter out trips that have already departed.""" |
| 172 | + # Compare against Dutch local time to align with ns_api timezone handling |
| 173 | + now = _now_nl() |
| 174 | + future_trips = [] |
| 175 | + for trip in trips: |
| 176 | + departure_time = ( |
| 177 | + trip.departure_time_actual |
| 178 | + if trip.departure_time_actual is not None |
| 179 | + else trip.departure_time_planned |
| 180 | + ) |
| 181 | + if departure_time is not None and ( |
| 182 | + departure_time.tzinfo is None |
| 183 | + or departure_time.tzinfo.utcoffset(departure_time) is None |
| 184 | + ): |
| 185 | + # Make naive datetimes timezone-aware using current reference tz |
| 186 | + departure_time = departure_time.replace(tzinfo=now.tzinfo) |
| 187 | + |
| 188 | + if departure_time and departure_time > now: |
| 189 | + future_trips.append(trip) |
| 190 | + return future_trips |
| 191 | + |
| 192 | + def _find_next_trip( |
| 193 | + self, future_trips: list[Trip], first_trip: Trip |
| 194 | + ) -> Trip | None: |
| 195 | + """Find the next trip with a different departure time than the first trip.""" |
| 196 | + next_trip = None |
| 197 | + if len(future_trips) > 1: |
| 198 | + first_time = ( |
| 199 | + first_trip.departure_time_actual |
| 200 | + if first_trip.departure_time_actual is not None |
| 201 | + else first_trip.departure_time_planned |
| 202 | + ) |
| 203 | + for trip in future_trips[1:]: |
| 204 | + trip_time = ( |
| 205 | + trip.departure_time_actual |
| 206 | + if trip.departure_time_actual is not None |
| 207 | + else trip.departure_time_planned |
| 208 | + ) |
| 209 | + if trip_time and first_time and trip_time > first_time: |
| 210 | + next_trip = trip |
| 211 | + break |
| 212 | + return next_trip |
0 commit comments