diff --git a/README.md b/README.md index bbdfbcc..2c29406 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,10 @@ Enable background data refresh. This will prevent requests from hanging while ne Standard Flask option. Will enabled enhanced logging and wildcard CORS headers. *default: False* +- **SERVICE_ALERTS** +Enable additional service alert information from the MTA's feed. +*default: False* + ## Generating a Stations File The MTA provides several static data files about the subway system but none include canonical information about each station. MTAPI includes a script that will parse the `stops.txt` and `transfers.txt` datasets provided by the MTA and attempt to group the different train stops into subway stations. MTAPI will use this JSON file for station names and locations. The grouping is not perfect and editing the resulting files is encouraged. diff --git a/app.py b/app.py index 740191f..492c173 100644 --- a/app.py +++ b/app.py @@ -22,7 +22,8 @@ MAX_TRAINS=10, MAX_MINUTES=30, CACHE_SECONDS=60, - THREADED=True + THREADED=True, + SERVICE_ALERTS=False ) _SETTINGS_ENV_VAR = 'MTAPI_SETTINGS' @@ -56,7 +57,8 @@ def default(self, obj): max_trains=app.config['MAX_TRAINS'], max_minutes=app.config['MAX_MINUTES'], expires_seconds=app.config['CACHE_SECONDS'], - threaded=app.config['THREADED']) + threaded=app.config['THREADED'], + service_alerts=app.config['SERVICE_ALERTS']) def response_wrapper(f): @wraps(f) diff --git a/mtapi/mtapi.py b/mtapi/mtapi.py index 448a0d2..6abc20b 100644 --- a/mtapi/mtapi.py +++ b/mtapi/mtapi.py @@ -23,6 +23,7 @@ def __init__(self, json): self.json = json self.trains = {} self.clear_train_data() + self.alerts = [] def __getitem__(self, key): return self.json[key] @@ -34,12 +35,21 @@ def add_train(self, route_id, direction, train_time, feed_time): 'time': train_time }) self.last_update = feed_time + + def add_alert(self, alert_type, alert_text): + # Only add alerts once + if not any(a['header_text'] == alert_text for a in self.alerts): + self.alerts.append({ + 'type': alert_type, + 'header_text': alert_text + }) def clear_train_data(self): self.trains['N'] = [] self.trains['S'] = [] self.routes = set() self.last_update = None + self.alerts = [] def sort_trains(self, max_trains): self.trains['S'] = sorted(self.trains['S'], key=itemgetter('time'))[:max_trains] @@ -52,6 +62,8 @@ def serialize(self): 'routes': self.routes, 'last_update': self.last_update } + if self.alerts: + out['service_alerts'] = self.alerts # Only add service alerts if they exist out.update(self.json) return out @@ -67,11 +79,15 @@ def serialize(self): 'https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/nyct%2Fgtfs-g' # G ] - def __init__(self, stations_file, expires_seconds=60, max_trains=10, max_minutes=30, threaded=False): + # GTFS feed for subway service alerts + _SERVICE_ALERT_URL = 'https://api-endpoint.mta.info/Dataservice/mtagtfsfeeds/camsys%2Fsubway-alerts' + + def __init__(self, stations_file, expires_seconds=60, max_trains=10, max_minutes=30, threaded=False, service_alerts=False): self._MAX_TRAINS = max_trains self._MAX_MINUTES = max_minutes self._EXPIRES_SECONDS = expires_seconds self._THREADED = threaded + self._GET_SERVICE_ALERTS = service_alerts self._stations = {} self._stops_to_stations = {} self._routes = {} @@ -114,6 +130,42 @@ def _load_mta_feed(self, feed_url): except (urllib.error.URLError, google.protobuf.message.DecodeError, ConnectionResetError) as e: logger.error('Couldn\'t connect to MTA server: ' + str(e)) return False + + def _is_alert_currently_active(self, entity, current_timestamp: int): + for period in entity.alert.active_period: + # Alert is active if current time is within the period + if period.start <= current_timestamp and (not period.end or period.end >= current_timestamp): + return True + return False + + def _get_alert_text(self, entity, language='en'): + # Fall back to 'en' when language isn't available. + # This is needed for elevator alerts, some text is + # better than no text in this case. + english_text = None + for translation in entity.alert.header_text.translation: + if translation.language == language: + return translation.text + elif translation.language == 'en': + english_text = translation.text + return english_text + + def _get_station_routes(self, station): + return { + train['route'] + for direction in station.trains.values() + for train in direction + } + + def _alert_applies_to_stop(self, informed, station_stops): + if not informed.HasField('stop_id'): + return False + + # Check both with and without direction suffix (N/S) + alert_stop = informed.stop_id + alert_stop_base = alert_stop.rstrip('NS') + + return any(stop in [alert_stop, alert_stop_base] for stop in station_stops) def _update(self): logger.info('updating...') @@ -128,6 +180,20 @@ def _update(self): routes = defaultdict(set) + # Get service alerts + service_alerts = [] + if self._GET_SERVICE_ALERTS: + logger.info('fetching service alerts...') + service_alerts_feed = self._load_mta_feed(self._SERVICE_ALERT_URL) + + if service_alerts_feed: + # Filter for active alerts + current_timestamp = int(self._last_update.timestamp()) + service_alerts = [ + entity for entity in service_alerts_feed.entity + if self._is_alert_currently_active(entity, current_timestamp) + ] + for i, feed_url in enumerate(self._FEED_URLS): mta_data = self._load_mta_feed(feed_url) @@ -170,6 +236,27 @@ def _update(self): for id in stations: stations[id].sort_trains(self._MAX_TRAINS) + # Add service alerts to stations + if self._GET_SERVICE_ALERTS and service_alerts: + for station_id, station in stations.items(): + station_routes = self._get_station_routes(station) + + for alert_entity in service_alerts: + alert_text = self._get_alert_text(alert_entity, 'en-html') + if not alert_text: + continue + + for informed in alert_entity.alert.informed_entity: + # Check if alert applies to this specific stop + if self._alert_applies_to_stop(informed, station['stops']): + station.add_alert('stop', alert_text) + break # Only add once per alert per station + + # Check if alert applies to a route serving the station + if informed.HasField('route_id') and informed.route_id in station_routes: + station.add_alert('route', alert_text) + break # Only add once per alert per station + with self._read_lock: self._routes = routes self._stations = stations diff --git a/settings.cfg.sample b/settings.cfg.sample index 3efb7e0..2939d52 100644 --- a/settings.cfg.sample +++ b/settings.cfg.sample @@ -1,7 +1,8 @@ STATIONS_FILE = './data/stations.json' CROSS_ORIGIN = 'http://yourdomain.com' -MAX_TRAINS=10 -MAX_MINUTES=30 -CACHE_SECONDS=60 -THREADED=True +MAX_TRAINS = 10 +MAX_MINUTES = 30 +CACHE_SECONDS = 60 +THREADED = True DEBUG = True +SERVICE_ALERTS = False