Skip to content

Commit bf4ee06

Browse files
authored
Refactor get_cartography return values and add to_folium_map functions (#393)
* Change return of get_cartography function * Add to_folium_map function * Update requirements * Fix * Bump version
1 parent 3565398 commit bf4ee06

File tree

6 files changed

+131
-34
lines changed

6 files changed

+131
-34
lines changed

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ networkx
55
numpy
66
osmnx
77
pandas
8-
Pillow
98
seaborn
109
tqdm
11-
opencv-python
1210
shapely
11+
folium

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,5 +521,6 @@ def run_stubgen(self):
521521
"numpy",
522522
"geopandas",
523523
"shapely",
524+
"folium",
524525
],
525526
)

src/dsf/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@
2121
graph_from_gdfs,
2222
graph_to_gdfs,
2323
create_manhattan_cartography,
24+
to_folium_map,
2425
)

src/dsf/dsf.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
static constexpr uint8_t DSF_VERSION_MAJOR = 4;
88
static constexpr uint8_t DSF_VERSION_MINOR = 7;
9-
static constexpr uint8_t DSF_VERSION_PATCH = 2;
9+
static constexpr uint8_t DSF_VERSION_PATCH = 3;
1010

1111
static auto const DSF_VERSION =
1212
std::format("{}.{}.{}", DSF_VERSION_MAJOR, DSF_VERSION_MINOR, DSF_VERSION_PATCH);

src/dsf/python/cartography.py

Lines changed: 65 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
standardization of attributes.
88
"""
99

10+
import folium
1011
import geopandas as gpd
1112
import networkx as nx
1213
import numpy as np
@@ -21,8 +22,7 @@ def get_cartography(
2122
consolidate_intersections: bool | float = 10,
2223
dead_ends: bool = False,
2324
infer_speeds: bool = False,
24-
return_type: str = "gdfs",
25-
) -> tuple | nx.DiGraph:
25+
) -> tuple[nx.DiGraph, gpd.GeoDataFrame, gpd.GeoDataFrame]:
2626
"""
2727
Retrieves and processes cartography data for a specified place using OpenStreetMap data.
2828
@@ -44,16 +44,14 @@ def get_cartography(
4444
infer_speeds (bool, optional): Whether to infer edge speeds based on road types. Defaults to False.
4545
If True, calls ox.routing.add_edge_speeds using np.nanmedian as aggregation function.
4646
Finally, the "maxspeed" attribute is replaced with the inferred "speed_kph", and the "travel_time" attribute is computed.
47-
return_type (str, optional): Type of return value. Options are "gdfs" (GeoDataFrames) or
48-
"graph" (NetworkX DiGraph). Defaults to "gdfs".
4947
5048
Returns:
51-
tuple | nx.DiGraph: If return_type is "gdfs", returns a tuple containing two GeoDataFrames:
49+
tuple[nx.DiGraph, gpd.GeoDataFrame, gpd.GeoDataFrame]: Returns a tuple containing:
50+
- NetworkX DiGraph with standardized attributes.
5251
- gdf_edges: GeoDataFrame with processed edge data, including columns like 'source',
5352
'target', 'nlanes', 'type', 'name', 'id', and 'geometry'.
5453
- gdf_nodes: GeoDataFrame with processed node data, including columns like 'id', 'type',
5554
and 'geometry'.
56-
If return_type is "graph", returns the NetworkX DiGraph with standardized attributes.
5755
"""
5856
if bbox is None and place_name is None:
5957
raise ValueError("Either place_name or bbox must be provided.")
@@ -223,32 +221,26 @@ def get_cartography(
223221
): # Check for NaN
224222
G.nodes[node]["type"] = "N/A"
225223

226-
# Return graph or GeoDataFrames based on return_type
227-
if return_type == "graph":
228-
return G
229-
elif return_type == "gdfs":
230-
# Convert back to MultiDiGraph temporarily for ox.graph_to_gdfs compatibility
231-
gdf_nodes, gdf_edges = ox.graph_to_gdfs(nx.MultiDiGraph(G))
224+
# Convert back to MultiDiGraph temporarily for ox.graph_to_gdfs compatibility
225+
gdf_nodes, gdf_edges = ox.graph_to_gdfs(nx.MultiDiGraph(G))
232226

233-
# Reset index and drop unnecessary columns (id, source, target already exist from graph)
234-
gdf_edges.reset_index(inplace=True)
235-
# Move the "id" column to the beginning
236-
id_col = gdf_edges.pop("id")
237-
gdf_edges.insert(0, "id", id_col)
227+
# Reset index and drop unnecessary columns (id, source, target already exist from graph)
228+
gdf_edges.reset_index(inplace=True)
229+
# Move the "id" column to the beginning
230+
id_col = gdf_edges.pop("id")
231+
gdf_edges.insert(0, "id", id_col)
238232

239-
# Ensure length is float
240-
gdf_edges["length"] = gdf_edges["length"].astype(float)
233+
# Ensure length is float
234+
gdf_edges["length"] = gdf_edges["length"].astype(float)
241235

242-
gdf_edges.drop(columns=["u", "v", "key"], inplace=True, errors="ignore")
236+
gdf_edges.drop(columns=["u", "v", "key"], inplace=True, errors="ignore")
243237

244-
# Reset index for nodes
245-
gdf_nodes.reset_index(inplace=True)
246-
gdf_nodes.drop(columns=["y", "x"], inplace=True, errors="ignore")
247-
gdf_nodes.rename(columns={"osmid": "id"}, inplace=True)
238+
# Reset index for nodes
239+
gdf_nodes.reset_index(inplace=True)
240+
gdf_nodes.drop(columns=["y", "x"], inplace=True, errors="ignore")
241+
gdf_nodes.rename(columns={"osmid": "id"}, inplace=True)
248242

249-
return gdf_edges, gdf_nodes
250-
else:
251-
raise ValueError("Invalid return_type. Choose 'gdfs' or 'graph'.")
243+
return G, gdf_edges, gdf_nodes
252244

253245

254246
def graph_from_gdfs(
@@ -460,6 +452,52 @@ def create_manhattan_cartography(
460452
return gdf_edges, gdf_nodes
461453

462454

455+
def to_folium_map(
456+
G: nx.DiGraph,
457+
which: str = "edges",
458+
) -> folium.Map:
459+
"""
460+
Converts a NetworkX DiGraph to a Folium map for visualization.
461+
Args:
462+
G (nx.DiGraph): The input DiGraph.
463+
which (str): Specify whether to visualize 'edges', 'nodes', or 'both'. Defaults to 'edges'.
464+
Returns:
465+
folium.Map: The Folium map with the graph visualized.
466+
"""
467+
468+
# Compute mean latitude and longitude for centering the map
469+
mean_lat = np.mean([data["geometry"].y for _, data in G.nodes(data=True)])
470+
mean_lon = np.mean([data["geometry"].x for _, data in G.nodes(data=True)])
471+
folium_map = folium.Map(location=[mean_lat, mean_lon], zoom_start=13)
472+
473+
if which in ("edges", "both"):
474+
# Add edges to the map
475+
for _, _, data in G.edges(data=True):
476+
line = data.get("geometry")
477+
if line:
478+
folium.PolyLine(
479+
locations=[(point[1], point[0]) for point in line.coords],
480+
color="blue",
481+
weight=2,
482+
opacity=0.7,
483+
popup=f"Edge ID: {data.get('id')}",
484+
).add_to(folium_map)
485+
if which in ("nodes", "both"):
486+
# Add nodes to the map
487+
for _, data in G.nodes(data=True):
488+
folium.CircleMarker(
489+
location=(data["geometry"].y, data["geometry"].x),
490+
radius=5,
491+
color="red",
492+
fill=True,
493+
fill_color="red",
494+
fill_opacity=0.7,
495+
popup=f"Node ID: {data.get('id')}",
496+
).add_to(folium_map)
497+
498+
return folium_map
499+
500+
463501
# if __name__ == "__main__":
464502
# # Produce data for tests
465503
# edges, nodes = get_cartography(

test/Test_cartography.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44

55
import pytest
66
import networkx as nx
7+
import folium
78
from dsf.python.cartography import (
89
get_cartography,
910
graph_to_gdfs,
1011
graph_from_gdfs,
1112
create_manhattan_cartography,
13+
to_folium_map,
1214
)
1315

1416

@@ -17,10 +19,7 @@ def test_consistency():
1719
A simple consistency test to verify that converting from GeoDataFrames to graph and back
1820
yields the same GeoDataFrames.
1921
"""
20-
G_CART = get_cartography("Postua, Piedmont, Italy", return_type="graph")
21-
edges_cart, nodes_cart = get_cartography(
22-
"Postua, Piedmont, Italy", return_type="gdfs"
23-
)
22+
G_CART, edges_cart, nodes_cart = get_cartography("Postua, Piedmont, Italy")
2423

2524
edges, nodes = graph_to_gdfs(G_CART)
2625

@@ -221,5 +220,64 @@ def test_rectangular_grid(self):
221220
assert len(edges) == expected_edges
222221

223222

223+
class TestToFoliumMap:
224+
"""Tests for to_folium_map function."""
225+
226+
@pytest.fixture
227+
def sample_graph(self):
228+
"""Create a sample graph for testing."""
229+
edges, nodes = create_manhattan_cartography(n_x=3, n_y=3)
230+
return graph_from_gdfs(edges, nodes)
231+
232+
def test_returns_folium_map(self, sample_graph):
233+
"""Test that the function returns a folium.Map object."""
234+
result = to_folium_map(sample_graph)
235+
assert isinstance(result, folium.Map)
236+
237+
def test_edges_only(self, sample_graph):
238+
"""Test visualization with edges only (default)."""
239+
result = to_folium_map(sample_graph, which="edges")
240+
assert isinstance(result, folium.Map)
241+
# Check that the map has children (the edges)
242+
assert len(result._children) > 0
243+
244+
def test_nodes_only(self, sample_graph):
245+
"""Test visualization with nodes only."""
246+
result = to_folium_map(sample_graph, which="nodes")
247+
assert isinstance(result, folium.Map)
248+
assert len(result._children) > 0
249+
250+
def test_both_edges_and_nodes(self, sample_graph):
251+
"""Test visualization with both edges and nodes."""
252+
result = to_folium_map(sample_graph, which="both")
253+
assert isinstance(result, folium.Map)
254+
# Should have more children than edges-only or nodes-only
255+
edges_only = to_folium_map(sample_graph, which="edges")
256+
nodes_only = to_folium_map(sample_graph, which="nodes")
257+
# 'both' should have children from edges and nodes combined
258+
# (minus the base tile layer which is common)
259+
assert len(result._children) >= len(edges_only._children)
260+
assert len(result._children) >= len(nodes_only._children)
261+
262+
def test_map_center_location(self, sample_graph):
263+
"""Test that the map is centered correctly."""
264+
result = to_folium_map(sample_graph)
265+
# The map should be centered around the mean of node coordinates
266+
# For a Manhattan grid centered at (0, 0), the center should be near (0, 0)
267+
location = result.location
268+
assert location is not None
269+
assert len(location) == 2
270+
# Check that location is reasonable (near 0,0 for default manhattan grid)
271+
assert -1 < location[0] < 1 # latitude
272+
assert -1 < location[1] < 1 # longitude
273+
274+
def test_default_which_parameter(self, sample_graph):
275+
"""Test that default 'which' parameter is 'edges'."""
276+
default_result = to_folium_map(sample_graph)
277+
edges_result = to_folium_map(sample_graph, which="edges")
278+
# Both should produce maps with the same number of children
279+
assert len(default_result._children) == len(edges_result._children)
280+
281+
224282
if __name__ == "__main__":
225283
pytest.main([__file__, "-v"])

0 commit comments

Comments
 (0)