Skip to content

Commit 07016bb

Browse files
committed
Fixed unit test cases
1 parent 40dc3f1 commit 07016bb

File tree

6 files changed

+262
-65
lines changed

6 files changed

+262
-65
lines changed

src/osm_osw_reformatter/serializer/osm/osm_graph.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,9 @@ def to_geojson(self, *args) -> None:
632632
polygon_features = []
633633
for n, d in self.G.nodes(data=True):
634634
d_copy = {**d}
635-
d_copy["_id"] = str(n)[1:]
635+
id_str = str(n)
636+
trimmed_id = id_str[1:] if isinstance(n, str) else id_str
637+
d_copy["_id"] = trimmed_id
636638
d_copy['ext:osm_id'] = str(d_copy.get('osm_id', d_copy["_id"]))
637639

638640
if OSWPointNormalizer.osw_point_filter(d):
@@ -665,7 +667,7 @@ def to_geojson(self, *args) -> None:
665667
polygon_features.append(
666668
{"type": "Feature", "geometry": geometry, "properties": d_copy}
667669
)
668-
else:
670+
elif OSWNodeNormalizer.osw_node_filter(d) or self.G.degree(n) > 0:
669671
d_copy['_id'] = str(n)
670672

671673
geometry = mapping(d_copy.pop('geometry'))

src/osm_osw_reformatter/serializer/osm/osm_normalizer.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -87,22 +87,46 @@ def process_output(self, osmnodes, osmways, osmrelations):
8787
"""
8888
mask_63bit = (1 << 63) - 1
8989

90+
def _set_id_tag(osm_obj, new_id):
91+
tags = getattr(osm_obj, "tags", None)
92+
if tags is None or not hasattr(tags, "__setitem__"):
93+
return
94+
95+
value = str(new_id)
96+
existing = tags.get("_id") if hasattr(tags, "get") else None
97+
98+
if isinstance(existing, list):
99+
tags["_id"] = [value]
100+
elif existing is None:
101+
# Determine if the container generally stores values as lists
102+
sample_value = None
103+
if hasattr(tags, "values"):
104+
for sample_value in tags.values():
105+
if sample_value is not None:
106+
break
107+
if isinstance(sample_value, list):
108+
tags["_id"] = [value]
109+
else:
110+
# Default to list storage to match ogr2osm's internal structures
111+
tags["_id"] = [value]
112+
else:
113+
tags["_id"] = value
114+
115+
def _normalise_id(osm_obj):
116+
if osm_obj.id < 0:
117+
new_id = osm_obj.id & mask_63bit
118+
osm_obj.id = new_id
119+
_set_id_tag(osm_obj, new_id)
120+
return new_id
121+
return osm_obj.id
90122

91123
# Fix node IDs
92124
for node in osmnodes:
93-
if node.id < 0:
94-
if mask_63bit == 3964026:
95-
print('node')
96-
print(node.id)
97-
node.id = node.id & mask_63bit
125+
_normalise_id(node)
98126

99127
# Fix ways and their node references
100128
for way in osmways:
101-
if way.id < 0:
102-
if mask_63bit == 3964026:
103-
print('node')
104-
print(way.id)
105-
way.id = way.id & mask_63bit
129+
_normalise_id(way)
106130

107131
# Detect how node references are stored
108132
node_refs = getattr(way, "nds", None) or getattr(way, "refs", None) or getattr(way, "nodeRefs", None) or getattr(way, "nodes", None)
@@ -115,10 +139,8 @@ def process_output(self, osmnodes, osmways, osmrelations):
115139
new_refs.append(ref & mask_63bit if ref < 0 else ref)
116140
elif hasattr(ref, "id"):
117141
if ref.id < 0:
118-
if mask_63bit == 3964026:
119-
print('ref')
120-
print(ref.id)
121142
ref.id = ref.id & mask_63bit
143+
_set_id_tag(ref, ref.id)
122144
new_refs.append(ref)
123145
else:
124146
new_refs.append(ref)
@@ -136,24 +158,25 @@ def process_output(self, osmnodes, osmways, osmrelations):
136158
# Fix relation IDs and their member refs
137159
for rel in osmrelations:
138160
if rel.id < 0:
139-
if mask_63bit == 3964026:
140-
print('rel')
141-
print(rel.id)
142161
rel.id = rel.id & mask_63bit
162+
_normalise_id(rel)
143163

144164
if hasattr(rel, "members"):
145165
for member in rel.members:
146166
if hasattr(member, "ref"):
147167
ref = member.ref
148168
if isinstance(ref, int) and ref < 0:
149-
if mask_63bit == 3964026:
150-
print('members ref 1')
151-
print(ref.id)
152169
member.ref = ref & mask_63bit
153170
elif hasattr(ref, "id") and ref.id < 0:
154-
if mask_63bit == 3964026:
155-
print('members ref 2')
156-
print(ref.id)
157171
ref.id = ref.id & mask_63bit
172+
_set_id_tag(ref, ref.id)
173+
174+
# Ensure deterministic ordering now that IDs have been normalised
175+
if hasattr(osmnodes, "sort"):
176+
osmnodes.sort(key=lambda n: n.id)
177+
if hasattr(osmways, "sort"):
178+
osmways.sort(key=lambda w: w.id)
179+
if hasattr(osmrelations, "sort"):
180+
osmrelations.sort(key=lambda r: r.id)
158181

159182

tests/unit_tests/test_osm_compliance/test_osm_compliance.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,17 @@ async def test_incline_tag_preserved(self):
5151
osw_files = res.generated_files
5252

5353
found_incline = False
54-
if osw_files:
55-
for f in osw_files:
56-
if f.endswith('.geojson'):
57-
with open(f) as fh:
58-
data = json.load(fh)
59-
for feature in data.get('features', []):
60-
props = feature.get('properties', {})
61-
if 'incline' in props:
62-
found_incline = True
63-
break
64-
if found_incline:
65-
break
54+
for f in osw_files:
55+
if f.endswith('.geojson'):
56+
with open(f) as fh:
57+
data = json.load(fh)
58+
for feature in data.get('features', []):
59+
props = feature.get('properties', {})
60+
if 'incline' in props:
61+
found_incline = True
62+
break
63+
if found_incline:
64+
break
6665

6766
self.assertTrue(found_incline, 'No incline tag found in OSW output')
6867

tests/unit_tests/test_serializer/test_osm_graph.py

Lines changed: 149 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,70 @@ def filter_func(u, v, d):
5353
list(filtered_graph.get_graph().edges(data=True))[0][2]["property"], "A"
5454
)
5555

56+
def test_node_adds_tagged_node_with_coordinates(self):
57+
class DummyNode:
58+
def __init__(self, node_id, tags, lon, lat):
59+
self.id = node_id
60+
self.tags = tags
61+
62+
class Location:
63+
def __init__(self, lon, lat):
64+
self.lon = lon
65+
self.lat = lat
66+
67+
self.location = Location(lon, lat)
68+
69+
tagged = DummyNode(101, {"highway": "footway"}, -122.3, 47.6)
70+
71+
self.osm_graph.node(tagged)
72+
73+
self.assertIn(101, self.mock_graph.nodes)
74+
node_data = self.mock_graph.nodes[101]
75+
self.assertEqual(node_data["highway"], "footway")
76+
self.assertEqual(node_data["lon"], -122.3)
77+
self.assertEqual(node_data["lat"], 47.6)
78+
79+
def test_node_skips_existing_identifiers(self):
80+
self.mock_graph.add_node(202, lon=0.0, lat=0.0, name="original")
81+
82+
class DummyNode:
83+
def __init__(self):
84+
self.id = 202
85+
self.tags = {"name": "replacement"}
86+
87+
class Location:
88+
lon = 1.0
89+
lat = 1.0
90+
91+
self.location = Location()
92+
93+
duplicate = DummyNode()
94+
95+
self.osm_graph.node(duplicate)
96+
97+
node_data = self.mock_graph.nodes[202]
98+
self.assertEqual(node_data["name"], "original")
99+
self.assertEqual(node_data["lon"], 0.0)
100+
self.assertEqual(node_data["lat"], 0.0)
101+
102+
def test_node_requires_tags(self):
103+
class DummyNode:
104+
def __init__(self):
105+
self.id = 303
106+
self.tags = {}
107+
108+
class Location:
109+
lon = 2.0
110+
lat = 2.0
111+
112+
self.location = Location()
113+
114+
empty = DummyNode()
115+
116+
self.osm_graph.node(empty)
117+
118+
self.assertNotIn(303, self.mock_graph.nodes)
119+
56120
@patch('src.osm_osw_reformatter.serializer.osm.osm_graph.mapping')
57121
def test_to_geojson(self, mock_mapping):
58122
# Mock mapping function to return a sample GeoJSON structure
@@ -457,16 +521,29 @@ def test_construct_geometries_line_node(self):
457521
class TestFromGeoJSON(unittest.TestCase):
458522
def setUp(self):
459523
# Create valid test GeoJSON files
460-
self.nodes_path = "test_nodes.geojson"
461-
self.edges_path = "test_edges.geojson"
524+
self.tempdir = TemporaryDirectory()
525+
self.nodes_path = os.path.join(self.tempdir.name, "nodes.geojson")
526+
self.edges_path = os.path.join(self.tempdir.name, "edges.geojson")
462527

463-
self.node_data = {
528+
def tearDown(self):
529+
self.tempdir.cleanup()
530+
531+
def _write_geojson(self, node_data, edge_data):
532+
with open(self.nodes_path, "w") as f:
533+
json.dump(node_data, f)
534+
535+
with open(self.edges_path, "w") as f:
536+
json.dump(edge_data, f)
537+
538+
539+
def test_from_geojson_populates_graph(self):
540+
node_data = {
464541
"type": "FeatureCollection",
465542
"features": [
466543
{
467544
"type": "Feature",
468545
"geometry": {"type": "Point", "coordinates": [1, 1]},
469-
"properties": {"_id": "1", "attribute": "value1"},
546+
"properties": {"_id": "1", "attribute": "value1", "ext:osm_id": "11"},
470547
},
471548
{
472549
"type": "Feature",
@@ -476,45 +553,88 @@ def setUp(self):
476553
],
477554
}
478555

479-
self.edge_data = {
556+
edge_data = {
480557
"type": "FeatureCollection",
481558
"features": [
482559
{
483560
"type": "Feature",
484561
"geometry": {"type": "LineString", "coordinates": [[1, 1], [2, 2]]},
485-
"properties": {"_id": "1", "_u_id": "1", "_v_id": "2", "attribute": "edge_value"},
562+
"properties": {
563+
"_id": "5",
564+
"_u_id": "1",
565+
"_v_id": "2",
566+
"attribute": "edge_value",
567+
"ext:osm_id": "99",
568+
},
486569
},
487570
],
488571
}
489572

490-
# Write the data to files
491-
with open(self.nodes_path, "w") as f:
492-
json.dump(self.node_data, f)
573+
self._write_geojson(node_data, edge_data)
574+
osm_graph = OSMGraph.from_geojson(self.nodes_path, self.edges_path)
575+
graph = osm_graph.get_graph()
493576

494-
with open(self.edges_path, "w") as f:
495-
json.dump(self.edge_data, f)
577+
self.assertEqual(set(graph.nodes), {1, 2})
578+
node_attrs = graph.nodes[1]
579+
self.assertIsInstance(node_attrs["geometry"], Point)
580+
self.assertEqual(node_attrs["lon"], 1)
581+
self.assertEqual(node_attrs["lat"], 1)
582+
self.assertEqual(node_attrs["osm_id"], 11)
496583

497-
def tearDown(self):
498-
# Clean up files after tests
499-
import os
500-
if os.path.exists(self.nodes_path):
501-
os.remove(self.nodes_path)
502-
if os.path.exists(self.edges_path):
503-
os.remove(self.edges_path)
504-
505-
@patch("src.osm_osw_reformatter.serializer.osm.osm_graph.OSMGraph.from_geojson")
506-
def test_from_geojson(self, mock_from_geojson):
507-
mock_graph = MagicMock()
508-
mock_graph.get_graph.return_value.nodes = {"1": {"geometry": Point(1, 1)}}
509-
mock_graph.get_graph.return_value.edges = {("1", "2"): {"geometry": LineString([(1, 1), (2, 2)])}}
510-
mock_from_geojson.return_value = mock_graph
584+
edges = list(graph.edges(keys=True, data=True))
585+
self.assertEqual(len(edges), 1)
586+
u, v, key, attrs = edges[0]
587+
self.assertEqual((u, v, key), (1, 2, 5))
588+
self.assertIsInstance(attrs["geometry"], LineString)
589+
self.assertEqual(attrs["osm_id"], 99)
590+
self.assertEqual(attrs["attribute"], "edge_value")
591+
592+
def test_from_geojson_preserves_non_numeric_identifiers(self):
593+
node_data = {
594+
"type": "FeatureCollection",
595+
"features": [
596+
{
597+
"type": "Feature",
598+
"geometry": {"type": "Point", "coordinates": [3, 3]},
599+
"properties": {"_id": "p123", "ext:osm_id": "node-1"},
600+
},
601+
{
602+
"type": "Feature",
603+
"geometry": {"type": "Point", "coordinates": [4, 4]},
604+
"properties": {"_id": "p456"},
605+
},
606+
],
607+
}
608+
609+
edge_data = {
610+
"type": "FeatureCollection",
611+
"features": [
612+
{
613+
"type": "Feature",
614+
"geometry": {"type": "LineString", "coordinates": [[3, 3], [4, 4]]},
615+
"properties": {
616+
"_id": "edge-7",
617+
"_u_id": "p123",
618+
"_v_id": "p456",
619+
"ext:osm_id": "edge-A",
620+
},
621+
},
622+
],
623+
}
624+
625+
self._write_geojson(node_data, edge_data)
511626

512627
osm_graph = OSMGraph.from_geojson(self.nodes_path, self.edges_path)
628+
graph = osm_graph.get_graph()
629+
630+
self.assertIn("p123", graph.nodes)
631+
self.assertEqual(graph.nodes["p123"]["osm_id"], "node-1")
513632

514-
# Assertions
515-
self.assertIsNotNone(osm_graph, "OSMGraph object should not be None")
516-
self.assertEqual(len(osm_graph.get_graph().nodes), 1)
517-
self.assertEqual(len(osm_graph.get_graph().edges), 1)
633+
edges = list(graph.edges(keys=True, data=True))
634+
self.assertEqual(len(edges), 1)
635+
u, v, key, attrs = edges[0]
636+
self.assertEqual((u, v, key), ("p123", "p456", "edge-7"))
637+
self.assertEqual(attrs["osm_id"], "edge-A")
518638

519639
def test_tagged_node_parser_skips_non_osw_nodes(self):
520640
graph = nx.MultiDiGraph()

0 commit comments

Comments
 (0)