Skip to content

Commit 2b4d42d

Browse files
authored
Merge pull request #30 from jeafreezy/cli-bug-fix
Feat: CLI support + Enhancements
2 parents 4ceb58c + 2136458 commit 2b4d42d

File tree

8 files changed

+270
-69
lines changed

8 files changed

+270
-69
lines changed

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,24 +107,24 @@ python -m graphfaker.cli --help
107107

108108
#### Generate a Synthetic Social Graph
109109
```sh
110-
python -m graphfaker.cli gen \
111-
--source faker \
110+
python -m graphfaker.cli \
111+
--fetcher faker \
112112
--total-nodes 100 \
113113
--total-edges 500
114114
```
115115

116116
#### Generate a Real-World Road Network (OSM)
117117
```sh
118-
python -m graphfaker.cli gen \
119-
--source osm \
118+
python -m graphfaker.cli \
119+
--fetcher osm \
120120
--place "Berlin, Germany" \
121121
--network-type drive
122122
```
123123

124124
#### Generate a Flight Network (Airlines/Airports/Flights)
125125
```sh
126-
python -m graphfaker.cli gen \
127-
--source flights \
126+
python -m graphfaker.cli \
127+
--fetcher flights \
128128
--country "United States" \
129129
--year 2024 \
130130
--month 1

graphfaker/cli.py

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,72 @@
1-
# graphfaker/cli.py
21
"""
32
Command-line interface for GraphFaker.
43
"""
4+
55
from venv import logger
66
import typer
77
from graphfaker.core import GraphFaker
8+
from graphfaker.enums import FetcherType
89
from graphfaker.fetchers.osm import OSMGraphFetcher
910
from graphfaker.fetchers.flights import FlightGraphFetcher
11+
from graphfaker.utils import parse_date_range
1012

1113
app = typer.Typer()
1214

1315

14-
@app.command()
16+
@app.command(short_help="Generate a graph using GraphFaker.")
1517
def gen(
16-
source: str = typer.Option("faker", help="Generation source faker library)"),
17-
# for faker source
18-
total_nodes: int = typer.Option(100, help="Total nodes for random mode"),
19-
total_edges: int = typer.Option(1000, help="Total edges for random mode"),
20-
# for osm source
18+
fetcher: FetcherType = typer.Option(FetcherType.FAKER, help="Fetcher type to use."),
19+
# for FetcherType.FAKER source
20+
total_nodes: int = typer.Option(100, help="Total nodes for random mode."),
21+
total_edges: int = typer.Option(1000, help="Total edges for random mode."),
22+
# for FetcherType.OSM source
2123
place: str = typer.Option(
22-
None, help="OSM place name (e.g., 'Soho Square, London, UK')"
24+
None, help="OSM place name (e.g., 'Soho Square, London, UK')."
2325
),
2426
address: str = typer.Option(
25-
None, help="OSM address (e.g., '1600 Amphitheatre Parkway, Mountain View, CA')"
27+
None, help="OSM address (e.g., '1600 Amphitheatre Parkway, Mountain View, CA.')"
2628
),
27-
bbox: str = typer.Option(None, help="OSM bounding box as 'north,south,east,west'"),
29+
bbox: str = typer.Option(None, help="OSM bounding box as 'north,south,east,west.'"),
2830
network_type: str = typer.Option(
29-
"drive", help="OSM network type: drive | walk | bike | all"
31+
"drive", help="OSM network type: drive | walk | bike | all."
3032
),
31-
simplify: str = typer.Option(True, help="Simplify OSM graph topology"),
32-
retain_all: bool = typer.Option(False, help="Retain all components in OSM graph"),
33+
simplify: bool = typer.Option(True, help="Simplify OSM graph topology."),
34+
retain_all: bool = typer.Option(False, help="Retain all components in OSM graph."),
3335
dist: int = typer.Option(
34-
1000, help="Search radius (meters) when fetching around address"
36+
1000, help="Search radius (meters) when fetching around address."
3537
),
36-
# for flight source
38+
# for FetcherType.FLIGHT source
3739
country: str = typer.Option(
38-
"United States", help="Filter airports by country for flight data"
40+
"United States",
41+
help="Filter airports by country for flight data. e.g 'United States'.",
42+
),
43+
year: int = typer.Option(
44+
2024, help="Year (YYYY) for single-month flight fetch. e.g. 2024."
45+
),
46+
month: int = typer.Option(
47+
1, help="Month (1-12) for single-month flight fetch. e.g. 1 for January."
48+
),
49+
date_range: str = typer.Option(
50+
None,
51+
help="Year, Month and day range (YYYY-MM-DD,YYYY-MM-DD) for flight data. e.g. '2024-01-01,2024-01-15'.",
3952
),
40-
year: int = typer.Option(2024, help="Year (YYYY) for single-month flight fetch"),
41-
month: int = typer.Option(1, help="Month (1-12) for single-month flight fetch"),
42-
date_range: tuple = typer.Option(None, help="Year and Month range for flight data"),
4353
):
4454
"""Generate a graph using GraphFaker."""
4555
gf = GraphFaker()
46-
if source == "faker":
47-
G = gf.generate_graph(total_nodes=total_nodes, total_edges=total_edges)
4856

49-
elif source == "osm":
57+
if fetcher == FetcherType.FAKER:
58+
59+
g = gf.generate_graph(total_nodes=total_nodes, total_edges=total_edges)
60+
print(g)
61+
return g
62+
63+
elif fetcher == FetcherType.OSM:
5064
# parse bbox string if provided
5165
bbox_tuple = None
5266
if bbox:
5367
north, south, east, west = map(float, bbox.split(","))
5468
bbox_tuple = (north, south, east, west)
55-
G = OSMGraphFetcher.fetch_network(
69+
g = OSMGraphFetcher.fetch_network(
5670
place=place,
5771
address=address,
5872
bbox=bbox_tuple,
@@ -61,28 +75,34 @@ def gen(
6175
retain_all=retain_all,
6276
dist=dist,
6377
)
64-
elif source == "flights":
78+
print(g)
79+
return g
80+
else:
81+
# Flight fetcher
82+
parsed_date_range = parse_date_range(date_range) if date_range else None
83+
84+
# validate year and month
85+
if not (1 <= month <= 12):
86+
raise ValueError("Month must be between 1 and 12.")
87+
if not (1900 <= year <= 2100):
88+
raise ValueError("Year must be between 1900 and 2100.")
6589

6690
airlines_df = FlightGraphFetcher.fetch_airlines()
6791

6892
airports_df = FlightGraphFetcher.fetch_airports(country=country)
6993

70-
# 2) Fetch on-time performance data
7194
flights_df = FlightGraphFetcher.fetch_flights(
72-
year=year, month=month, date_range=date_range
95+
year=year, month=month, date_range=parsed_date_range
7396
)
7497
logger.info(
7598
f"Fetched {len(airlines_df)} airlines, "
7699
f"{len(airports_df)} airports, "
77100
f"{len(flights_df)} flights."
78101
)
79102

80-
# 3) Build the NetworkX graph
81-
G = FlightGraphFetcher.build_graph(airlines_df, airports_df, flights_df)
82-
83-
else:
84-
typer.echo(f"Source '{source}' not supported.")
85-
raise typer.Exit(code=1)
103+
g = FlightGraphFetcher.build_graph(airlines_df, airports_df, flights_df)
104+
print(g)
105+
return g
86106

87107

88108
if __name__ == "__main__":

graphfaker/core.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
A multi-domain network connecting entities across social, geographical, and commercial dimensions.
33
"""
44

5+
from typing import Optional
56
import networkx as nx
67
import random
78
from faker import Faker
@@ -203,9 +204,9 @@ def generate_edges(self, total_edges=1000):
203204

204205
def _generate_osm(
205206
self,
206-
place: str = None,
207-
address: str = None,
208-
bbox: tuple = None,
207+
place: Optional[str] = None,
208+
address: Optional[str] = None,
209+
bbox: Optional[tuple] = None,
209210
network_type: str = "drive",
210211
simplify: bool = True,
211212
retain_all: bool = False,
@@ -219,17 +220,17 @@ def _generate_osm(
219220
network_type=network_type,
220221
simplify=simplify,
221222
retain_all=retain_all,
222-
dist=dist
223+
dist=dist,
223224
)
224225
self.G = G
225226
return G
226227

227228
def _generate_flights(
228229
self,
229230
country: str = "United States",
230-
year: int = None,
231-
month: int = None,
232-
date_range: tuple = None,
231+
year: Optional[int] = None,
232+
month: Optional[int] = None,
233+
date_range: Optional[tuple] = None,
233234
):
234235
"""
235236
Fetch flights, airport, and airline via FlightFetcher
@@ -250,14 +251,15 @@ def _generate_flights(
250251

251252
G = FlightGraphFetcher.build_graph(airlines_df, airports_df, flights_df)
252253
self.G = G
253-
return G
254+
254255
# Inform users of which span was downloaded
255256
if date_range:
256257
start, end = date_range
257258
logger.info(f"Flight data covers {start} -> {end}")
258259

259260
else:
260261
logger.info(f"Flight data for {year}-{month:02d}")
262+
return G
261263

262264
def _generate_faker(self, total_nodes=100, total_edges=1000):
263265
"""Generates the complete Social Knowledge Graph."""
@@ -270,28 +272,34 @@ def generate_graph(
270272
source: str = "faker",
271273
total_nodes: int = 100,
272274
total_edges: int = 1000,
273-
place: str = None,
274-
address: str = None,
275-
bbox: tuple = None,
275+
place: Optional[str] = None,
276+
address: Optional[str] = None,
277+
bbox: Optional[tuple] = None,
276278
network_type: str = "drive",
277279
simplify: bool = True,
278280
retain_all: bool = False,
279281
dist: float = 1000,
280282
country: str = "United States",
281283
year: int = 2024,
282284
month: int = 1,
283-
date_range: tuple = None,
285+
date_range: Optional[tuple] = None,
284286
) -> nx.DiGraph:
285287
"""
286288
Unified entrypoint: choose 'random' or 'osm'.
287289
Pass kwargs depending on source.
288290
"""
291+
289292
if source == "faker":
290293
return self._generate_faker(
291-
total_nodes=total_nodes,
292-
total_edges=total_edges
294+
total_nodes=total_nodes, total_edges=total_edges
293295
)
294296
elif source == "osm":
297+
logger.info(
298+
f"Generating OSM graph with source={source}, "
299+
f"place={place}, address={address}, bbox={bbox}, "
300+
f"network_type={network_type}, simplify={simplify}, "
301+
f"retain_all={retain_all}, dist={dist}"
302+
)
295303
return self._generate_osm(
296304
place=place,
297305
address=address,

graphfaker/enums.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from enum import Enum
2+
3+
4+
class FetcherType(str, Enum):
5+
"""Enum for different fetcher types."""
6+
7+
OSM = "osm"
8+
FLIGHTS = "flights"
9+
FAKER = "faker"

graphfaker/fetchers/flights.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import pandas as pd
3434
from tqdm.auto import tqdm
3535
import networkx as nx
36-
import time
36+
3737

3838
# suppress only the single warning from unverified HTTPS
3939
import urllib3
@@ -124,6 +124,7 @@ def fetch_airlines() -> pd.DataFrame:
124124
Raises:
125125
HTTPError if download fails.
126126
"""
127+
logger.info("Fetching airlines lookup from BTS…")
127128
resp = requests.get(AIRLINE_LOOKUP_URL, verify=False)
128129
resp.raise_for_status()
129130
df = pd.read_csv(StringIO(resp.text))
@@ -145,6 +146,7 @@ def fetch_airports(
145146
Returns:
146147
pd.DataFrame with columns ['faa','name','city','country','lat','lon']
147148
"""
149+
logger.info("Fetching airports dataset from OpenFlights…")
148150
df = pd.read_csv(
149151
AIRPORTS_URL,
150152
header=None,
@@ -220,6 +222,10 @@ def fetch_flights(
220222
Raises:
221223
ValueError if neither valid year/month nor date_range provided.
222224
"""
225+
logger.info(
226+
f"Fetching flight performance data for {year}-{month:02d} "
227+
f"or date range {date_range}…"
228+
)
223229

224230
def load_month(y, m):
225231
url = f"https://transtats.bts.gov/PREZIP/On_Time_Reporting_Carrier_On_Time_Performance_1987_present_{y}_{m}.zip"

0 commit comments

Comments
 (0)