Skip to content

Commit 819d817

Browse files
Merge pull request #16 from JarvusInnovations/themightychris/multi-system-download
Add multi-system agency support to download script
2 parents e7b6c64 + 1ebefef commit 819d817

File tree

3 files changed

+144
-30
lines changed

3 files changed

+144
-30
lines changed

docs/downloading_data.md

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ This shows all agencies with their feed counts and available date ranges.
4949
Download all feeds for a specific agency:
5050

5151
```bash
52-
# Download SEPTA data
52+
# Download SEPTA data (all systems)
5353
uv run python scripts/download_data.py --agency septa --date 2026-01-20
5454

5555
# Download AC Transit (same as --defaults)
@@ -58,6 +58,34 @@ uv run python scripts/download_data.py --agency actransit
5858

5959
The script shows estimated download sizes before downloading.
6060

61+
## Multi-System Agencies
62+
63+
Some agencies, such as SEPTA, operate multiple transit systems (for example, bus and rail), each with its own set of feeds.
64+
65+
### View system breakdown
66+
67+
```bash
68+
uv run python scripts/download_data.py --list
69+
```
70+
71+
Output shows systems for multi-system agencies:
72+
73+
```
74+
septa SEPTA 6 2026-01-01 to 2026-01-25
75+
└─ septa/bus Bus 3 2026-01-01 to 2026-01-25
76+
└─ septa/rail Regional Rail 3 2026-01-01 to 2026-01-25
77+
```
78+
79+
### Download a specific system
80+
81+
```bash
82+
# Just SEPTA bus data
83+
uv run python scripts/download_data.py --agency septa/bus --date 2026-01-20
84+
85+
# Just SEPTA rail data
86+
uv run python scripts/download_data.py --agency septa/rail --date 2026-01-20
87+
```
88+
6189
## Usage Examples
6290

6391
### Download defaults for a specific date
@@ -92,7 +120,7 @@ uv run python scripts/download_data.py \
92120
|--------|-------------|
93121
| `--list` | List available agencies from inventory |
94122
| `--defaults` | Download AC Transit data for all feed types |
95-
| `--agency AGENCY` | Download all feeds for an agency (e.g., septa, vta) |
123+
| `--agency AGENCY` | Download feeds for an agency (e.g., `septa`) or agency/system (e.g., `septa/bus`) |
96124
| `--date DATE` | Date for `--defaults`/`--agency` mode (default: 2026-01-24) |
97125
| `--feed-type TYPE` | One of: `vehicle_positions`, `trip_updates`, `service_alerts` (advanced) |
98126
| `--feed-url URL` | Plain feed URL (advanced) |

models/staging/stg_available_feeds.sql

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ SELECT
1717
base64url AS feed_base64,
1818
agency_id,
1919
agency_name,
20+
system_id,
21+
system_name,
2022
feed_type,
2123
date_min,
2224
date_max,

scripts/download_data.py

Lines changed: 112 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# Download default sample data (AC Transit, all feed types)
1313
uv run python scripts/download_data.py --defaults
1414
15-
# Download all feeds for a specific agency
15+
# Download all feeds for a specific agency (or agency/system)
1616
uv run python scripts/download_data.py --agency septa --date 2026-01-20
17+
uv run python scripts/download_data.py --agency septa/bus --date 2026-01-20
1718
1819
# Download a specific feed type and date range (advanced)
1920
uv run python scripts/download_data.py \
@@ -57,39 +58,79 @@ def fetch_inventory() -> list[dict]:
5758

5859

5960
def get_agencies(inventory: list[dict]) -> dict:
    """Group inventory feed records by agency and, within each agency, by system.

    Args:
        inventory: Feed records. Each record must provide ``agency_id``,
            ``agency_name``, ``feed_type``, ``date_min`` and ``date_max``;
            ``system_id`` and ``system_name`` are optional.

    Returns:
        Dict keyed by agency id. Each value has:
            ``name``      - agency name taken from the first record seen,
            ``systems``   - dict keyed by system id (``None`` for feeds that
                            carry no system id), each holding ``name``,
                            ``feeds`` (feed_type -> record) and its own
                            ``date_min`` / ``date_max``,
            ``date_min`` / ``date_max`` - range spanning all of the agency's
                            feeds.
        Dates are combined with min()/max(), so they must be mutually
        comparable (ISO ``YYYY-MM-DD`` strings order correctly).
    """
    agencies = {}
    for feed in inventory:
        aid = feed["agency_id"]
        # Normalise a missing *or empty* system id to None so that every
        # consumer agrees on what "no named system" means; elsewhere in this
        # script falsy ids are already treated as the unnamed system.
        sid = feed.get("system_id") or None

        agency = agencies.get(aid)
        if agency is None:
            agency = agencies[aid] = {
                "name": feed["agency_name"],
                "systems": {},
                "date_min": feed["date_min"],
                "date_max": feed["date_max"],
            }

        systems = agency["systems"]
        if sid not in systems:
            systems[sid] = {
                "name": feed.get("system_name"),
                "feeds": {},
                "date_min": feed["date_min"],
                "date_max": feed["date_max"],
            }

        # A later record with the same feed_type replaces the earlier one.
        system = systems[sid]
        system["feeds"][feed["feed_type"]] = feed
        system["date_min"] = min(system["date_min"], feed["date_min"])
        system["date_max"] = max(system["date_max"], feed["date_max"])

        # Widen the agency-level range to cover this feed as well.
        agency["date_min"] = min(agency["date_min"], feed["date_min"])
        agency["date_max"] = max(agency["date_max"], feed["date_max"])

    return agencies
7693

7794

95+
def parse_agency_arg(value: str) -> tuple[str, str | None]:
    """Split an ``--agency`` argument into ``(agency_id, system_id)``.

    Accepts either a bare agency id (``'septa'``) or an ``agency/system``
    pair (``'septa/bus'``). Only the first ``/`` separates the two parts.

    Args:
        value: Raw argument string.

    Returns:
        ``(agency_id, system_id)``; ``system_id`` is ``None`` when no system
        part was given, including a dangling separator such as ``'septa/'``.
    """
    agency_id, _, system_id = value.partition("/")
    # An empty system part means "no system requested". Returning "" here
    # would be taken as a (non-existent) system named "" by the caller.
    return agency_id, system_id or None
101+
102+
78103
def list_agencies(inventory: list[dict]) -> None:
    """Print a table of available agencies with a per-system breakdown.

    Each agency gets one row with its total feed count (summed over all of
    its systems) and its overall date range. An agency with at least one
    named system (a system id that is not ``None``) also gets one indented
    row per named system, labelled ``<agency>/<system>`` - the form the
    ``--agency`` option accepts for downloading a single system.

    Args:
        inventory: Feed records, grouped here via get_agencies().
    """
    agencies = get_agencies(inventory)

    print("\nAvailable agencies:\n")
    print(f" {'Agency ID':<22} {'Name':<32} {'Feeds':<6} {'Date Range'}")
    print(f" {'-' * 22} {'-' * 32} {'-' * 6} {'-' * 24}")

    for aid in sorted(agencies):
        info = agencies[aid]
        systems = info["systems"]

        # Agency row: feed count is the total across all of its systems.
        total_feeds = sum(len(s["feeds"]) for s in systems.values())
        print(f" {aid:<22} {info['name']:<32} {total_feeds:<6} {info['date_min']} to {info['date_max']}")

        # Only named systems get their own row; feeds stored under the None
        # key are already included in the agency total above.
        for sid in sorted(s for s in systems if s is not None):
            sys_info = systems[sid]
            sys_name = sys_info["name"] or sid  # fall back to the id when unnamed
            feed_count = len(sys_info["feeds"])
            sys_key = f"{aid}/{sid}"
            print(f" └─ {sys_key:<20} {sys_name:<32} {feed_count:<6} {sys_info['date_min']} to {sys_info['date_max']}")

    print("\nUse --agency <id> to download all feeds for an agency.")
    print("Use --agency <id>/<system> to download a specific system.")
    print("Example: uv run python scripts/download_data.py --agency septa/bus --date 2026-01-20")
93134

94135

95136
def format_size(bytes_size: int) -> str:
@@ -157,11 +198,17 @@ def download_feed_data(
157198
return downloaded, skipped
158199

159200

160-
def download_agency(agency_id: str, date: str, output_dir: Path, inventory: list[dict]) -> dict[str, tuple[int, int]]:
161-
"""Download all feeds for an agency.
201+
def download_agency(
202+
agency_id: str,
203+
date: str,
204+
output_dir: Path,
205+
inventory: list[dict],
206+
system_id: str | None = None,
207+
) -> dict[str, tuple[int, int]]:
208+
"""Download feeds for an agency, optionally filtered by system.
162209
163210
Returns:
164-
Dict mapping feed_type to (downloaded, skipped) counts
211+
Dict mapping feed key to (downloaded, skipped) counts
165212
"""
166213
agencies = get_agencies(inventory)
167214

@@ -172,37 +219,65 @@ def download_agency(agency_id: str, date: str, output_dir: Path, inventory: list
172219
return {}
173220

174221
agency = agencies[agency_id]
175-
feeds = agency["feeds"]
222+
systems = agency["systems"]
223+
224+
# Filter to specific system if requested
225+
if system_id is not None:
226+
if system_id not in systems:
227+
available_systems = [s for s in systems.keys() if s is not None]
228+
if available_systems:
229+
print(f"Error: Unknown system '{system_id}' for agency '{agency_id}'")
230+
print(f"Available systems: {', '.join(sorted(available_systems))}")
231+
else:
232+
print(f"Error: Agency '{agency_id}' has no named systems")
233+
return {}
234+
systems_to_download = {system_id: systems[system_id]}
235+
else:
236+
systems_to_download = systems
176237

177238
# Validate date is in range
178239
if date < agency["date_min"] or date > agency["date_max"]:
179240
print(f"Warning: Date {date} is outside available range ({agency['date_min']} to {agency['date_max']})")
180241

181-
# Estimate sizes
242+
# Collect all feeds to download
243+
all_feeds = {}
244+
for sid, sys_info in systems_to_download.items():
245+
for feed_type, feed in sys_info["feeds"].items():
246+
# Use compound key to avoid collisions
247+
key = f"{sid or 'default'}:{feed_type}"
248+
all_feeds[key] = (sid, feed_type, feed)
249+
250+
# Estimate sizes and display plan
182251
total_bytes = 0
183-
print(f"\nDownloading {agency['name']} data for {date}:")
184-
for feed_type in sorted(feeds.keys()):
185-
feed = feeds[feed_type]
186-
# Estimate per-day size (total_bytes / days in range)
252+
system_label = f" ({system_id})" if system_id else ""
253+
print(f"\nDownloading {agency['name']}{system_label} data for {date}:")
254+
255+
for key in sorted(all_feeds.keys()):
256+
sid, feed_type, feed = all_feeds[key]
187257
days_available = (datetime.strptime(feed["date_max"], "%Y-%m-%d") -
188258
datetime.strptime(feed["date_min"], "%Y-%m-%d")).days + 1
189259
estimated_size = feed["total_bytes"] // max(days_available, 1)
190260
total_bytes += estimated_size
191-
print(f" {feed_type}: ~{format_size(estimated_size)}")
261+
262+
sys_label = f" [{sid}]" if sid and len(systems_to_download) > 1 else ""
263+
print(f" {feed_type}{sys_label}: ~{format_size(estimated_size)}")
264+
192265
print(f" Total: ~{format_size(total_bytes)}")
193266

267+
# Download feeds
194268
results = {}
195-
for feed_type in sorted(feeds.keys()):
196-
feed = feeds[feed_type]
197-
print(f"\n{feed_type}:")
269+
for key in sorted(all_feeds.keys()):
270+
sid, feed_type, feed = all_feeds[key]
271+
sys_label = f" [{sid}]" if sid and len(systems_to_download) > 1 else ""
272+
print(f"\n{feed_type}{sys_label}:")
198273
downloaded, skipped = download_feed_data(
199274
feed_type=feed_type,
200275
feed_base64=feed["base64url"],
201276
start_date=date,
202277
end_date=date,
203278
output_dir=output_dir,
204279
)
205-
results[feed_type] = (downloaded, skipped)
280+
results[key] = (downloaded, skipped)
206281

207282
return results
208283

@@ -213,9 +288,16 @@ def print_summary(results: dict[str, tuple[int, int]], output_dir: Path) -> None
213288
print("Summary:")
214289
total_downloaded = 0
215290
total_skipped = 0
216-
for feed_type, (downloaded, skipped) in results.items():
291+
for key, (downloaded, skipped) in sorted(results.items()):
292+
# Parse compound key if present (e.g., "bus:vehicle_positions")
293+
if ":" in key:
294+
sys_id, feed_type = key.split(":", 1)
295+
display_name = f"{feed_type} [{sys_id}]" if sys_id != "default" else feed_type
296+
else:
297+
display_name = key
298+
217299
status = "✓" if downloaded > 0 or skipped > 0 else "✗"
218-
print(f" {status} {feed_type}: {downloaded} downloaded, {skipped} skipped")
300+
print(f" {status} {display_name}: {downloaded} downloaded, {skipped} skipped")
219301
total_downloaded += downloaded
220302
total_skipped += skipped
221303

@@ -240,8 +322,9 @@ def main():
240322
# Download sample data (AC Transit, all feed types)
241323
%(prog)s --defaults
242324
243-
# Download all feeds for a specific agency
325+
# Download all feeds for a specific agency (or agency/system)
244326
%(prog)s --agency septa --date 2026-01-20
327+
%(prog)s --agency septa/bus --date 2026-01-20
245328
246329
# Download a specific feed (advanced)
247330
%(prog)s --feed-type vehicle_positions \\
@@ -267,7 +350,7 @@ def main():
267350
)
268351
parser.add_argument(
269352
"--agency",
270-
help="Download all feeds for an agency (e.g., actransit, septa, vta)",
353+
help="Download feeds for an agency (e.g., septa) or agency/system (e.g., septa/bus)",
271354
)
272355
parser.add_argument(
273356
"--date",
@@ -356,7 +439,8 @@ def main():
356439
if not inventory:
357440
print("Error: Could not fetch inventory. Check your internet connection.")
358441
return
359-
results = download_agency(args.agency, args.date, args.output_dir, inventory)
442+
agency_id, system_id = parse_agency_arg(args.agency)
443+
results = download_agency(agency_id, args.date, args.output_dir, inventory, system_id)
360444
if results:
361445
print_summary(results, args.output_dir)
362446
return

0 commit comments

Comments
 (0)