Skip to content

Commit 79cea80

Browse files
authored
feat: Add comprehensive KNN join integration tests and benchmarks (#65)
1 parent 660da68 commit 79cea80

File tree

2 files changed

+603
-133
lines changed

2 files changed

+603
-133
lines changed

benchmarks/test_knn.py

Lines changed: 164 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,20 @@
1717
import json
1818
import pytest
1919
from test_bench_base import TestBenchBase
20-
from sedonadb.testing import SedonaDB
20+
from sedonadb.testing import SedonaDB, PostGIS, DuckDB
2121

2222

2323
class TestBenchKNN(TestBenchBase):
2424
def setup_class(self):
2525
"""Setup test data for KNN benchmarks"""
2626
self.sedonadb = SedonaDB.create_or_skip()
27+
self.postgis = PostGIS.create_or_skip()
28+
self.duckdb = DuckDB.create_or_skip()
2729

2830
# Create building-like polygons (index side - fewer, larger geometries)
29-
# Note: Dataset sizes are limited to avoid performance issues observed when processing
30-
# very large synthetic datasets. Large synthetic datasets have been observed to cause
31-
# memory pressure or performance degradation in DataFusion operations.
3231
building_options = {
3332
"geom_type": "Polygon",
34-
"target_rows": 2_000, # Reasonable size for benchmarking
33+
"target_rows": 2_000,
3534
"vertices_per_linestring_range": [4, 8],
3635
"size_range": [0.001, 0.01],
3736
"seed": 42,
@@ -46,8 +45,10 @@ def setup_class(self):
4645
"""
4746
building_tab = self.sedonadb.execute_and_collect(building_query)
4847
self.sedonadb.create_table_arrow("knn_buildings", building_tab)
48+
self.postgis.create_table_arrow("knn_buildings", building_tab)
49+
self.duckdb.create_table_arrow("knn_buildings", building_tab)
4950

50-
# Create trip pickup points (probe side - many small geometries)
51+
# Create trip pickup points (probe side)
5152
trip_options = {
5253
"geom_type": "Point",
5354
"target_rows": 10_000,
@@ -62,25 +63,31 @@ def setup_class(self):
6263
"""
6364
trip_tab = self.sedonadb.execute_and_collect(trip_query)
6465
self.sedonadb.create_table_arrow("knn_trips", trip_tab)
66+
self.postgis.create_table_arrow("knn_trips", trip_tab)
67+
self.duckdb.create_table_arrow("knn_trips", trip_tab)
6568

6669
# Create a smaller test dataset for quick tests
6770
small_building_query = """
6871
SELECT * FROM knn_buildings LIMIT 1000
6972
"""
7073
small_building_tab = self.sedonadb.execute_and_collect(small_building_query)
7174
self.sedonadb.create_table_arrow("knn_buildings_small", small_building_tab)
75+
self.postgis.create_table_arrow("knn_buildings_small", small_building_tab)
76+
self.duckdb.create_table_arrow("knn_buildings_small", small_building_tab)
7277

7378
small_trip_query = """
7479
SELECT * FROM knn_trips LIMIT 5000
7580
"""
7681
small_trip_tab = self.sedonadb.execute_and_collect(small_trip_query)
7782
self.sedonadb.create_table_arrow("knn_trips_small", small_trip_tab)
83+
self.postgis.create_table_arrow("knn_trips_small", small_trip_tab)
84+
self.duckdb.create_table_arrow("knn_trips_small", small_trip_tab)
7885

7986
@pytest.mark.parametrize("k", [1, 5, 10])
80-
@pytest.mark.parametrize("use_spheroid", [False, True])
87+
@pytest.mark.parametrize("engine", [SedonaDB, PostGIS, DuckDB])
8188
@pytest.mark.parametrize("dataset_size", ["small", "large"])
82-
def test_knn_performance(self, benchmark, k, use_spheroid, dataset_size):
83-
"""Benchmark KNN query performance with different parameters"""
89+
def test_knn_performance(self, benchmark, k, engine, dataset_size):
90+
"""Benchmark KNN query performance across SedonaDB, PostGIS, and DuckDB"""
8491

8592
if dataset_size == "small":
8693
trip_table = "knn_trips_small"
@@ -89,138 +96,162 @@ def test_knn_performance(self, benchmark, k, use_spheroid, dataset_size):
8996
else:
9097
trip_table = "knn_trips_small"
9198
building_table = "knn_buildings"
92-
trip_limit = 500
99+
trip_limit = 1000
93100

94-
spheroid_str = "TRUE" if use_spheroid else "FALSE"
101+
# Get the appropriate engine instance
102+
eng = self._get_eng(engine)
95103

96104
def run_knn_query():
97-
query = f"""
98-
WITH trip_sample AS (
99-
SELECT trip_id, geom as trip_geom
100-
FROM {trip_table}
101-
LIMIT {trip_limit}
102-
),
103-
building_with_geom AS (
104-
SELECT building_id, name, geom as building_geom
105-
FROM {building_table}
106-
)
107-
SELECT
108-
t.trip_id,
109-
b.building_id,
110-
b.name,
111-
ST_Distance(t.trip_geom, b.building_geom) as distance
112-
FROM trip_sample t
113-
JOIN building_with_geom b ON ST_KNN(t.trip_geom, b.building_geom, {k}, {spheroid_str})
114-
ORDER BY t.trip_id, distance
115-
"""
116-
result = self.sedonadb.execute_and_collect(query)
117-
return len(result) # Return result count for verification
105+
if engine == SedonaDB:
106+
# SedonaDB syntax using ST_KNN function
107+
query = f"""
108+
WITH trip_sample AS (
109+
SELECT trip_id, geom as trip_geom
110+
FROM {trip_table}
111+
LIMIT {trip_limit}
112+
),
113+
building_with_geom AS (
114+
SELECT building_id, name, geom as building_geom
115+
FROM {building_table}
116+
)
117+
SELECT
118+
t.trip_id,
119+
b.building_id,
120+
b.name,
121+
ST_Distance(t.trip_geom, b.building_geom) as distance
122+
FROM trip_sample t
123+
JOIN building_with_geom b ON ST_KNN(t.trip_geom, b.building_geom, {k}, FALSE)
124+
ORDER BY t.trip_id, distance
125+
"""
126+
elif engine == PostGIS:
127+
# PostGIS syntax using distance operator and window functions
128+
query = f"""
129+
WITH trip_sample AS (
130+
SELECT trip_id, geom as trip_geom
131+
FROM {trip_table}
132+
LIMIT {trip_limit}
133+
),
134+
building_with_geom AS (
135+
SELECT building_id, name, geom as building_geom
136+
FROM {building_table}
137+
),
138+
ranked_neighbors AS (
139+
SELECT
140+
t.trip_id,
141+
b.building_id,
142+
b.name,
143+
ST_Distance(t.trip_geom, b.building_geom) as distance,
144+
ROW_NUMBER() OVER (PARTITION BY t.trip_id ORDER BY t.trip_geom <-> b.building_geom) as rn
145+
FROM trip_sample t
146+
CROSS JOIN building_with_geom b
147+
)
148+
SELECT trip_id, building_id, name, distance
149+
FROM ranked_neighbors
150+
WHERE rn <= {k}
151+
ORDER BY trip_id, distance
152+
"""
153+
else: # DuckDB
154+
# DuckDB KNN simulation: rank every trip/building pair by ST_Distance and keep the top k
155+
# Since DuckDB doesn't have native KNN, we use a cross join with distance calculation and ranking
156+
query = f"""
157+
WITH trip_sample AS (
158+
SELECT trip_id, geom as trip_geom
159+
FROM {trip_table}
160+
LIMIT {trip_limit}
161+
),
162+
building_with_geom AS (
163+
SELECT building_id, name, geom as building_geom
164+
FROM {building_table}
165+
),
166+
distances_calculated AS (
167+
SELECT
168+
t.trip_id,
169+
b.building_id,
170+
b.name,
171+
ST_Distance(t.trip_geom, b.building_geom) as distance
172+
FROM trip_sample t
173+
CROSS JOIN building_with_geom b
174+
),
175+
ranked_neighbors AS (
176+
SELECT *,
177+
ROW_NUMBER() OVER (PARTITION BY trip_id ORDER BY distance ASC) as rn
178+
FROM distances_calculated
179+
)
180+
SELECT trip_id, building_id, name, distance
181+
FROM ranked_neighbors
182+
WHERE rn <= {k}
183+
ORDER BY trip_id, distance
184+
"""
118185

119-
# Run the benchmark
120-
result_count = benchmark(run_knn_query)
186+
result = eng.execute_and_collect(query)
187+
return len(result)
121188

122-
# Verify we got the expected number of results (trips * k)
123-
expected_count = trip_limit * k
124-
assert result_count == expected_count, (
125-
f"Expected {expected_count} results, got {result_count}"
126-
)
189+
# Run the benchmark
190+
benchmark(run_knn_query)
127191

128192
@pytest.mark.parametrize("k", [1, 5, 10, 20])
129-
def test_knn_scalability_by_k(self, benchmark, k):
130-
"""Test how KNN performance scales with increasing k values"""
131-
132-
def run_knn_query():
133-
query = f"""
134-
WITH trip_sample AS (
135-
SELECT trip_id, geom as trip_geom
136-
FROM knn_trips_small
137-
LIMIT 50 -- Small sample for k scaling test
138-
)
139-
SELECT
140-
COUNT(*) as result_count
141-
FROM trip_sample t
142-
JOIN knn_buildings_small b ON ST_KNN(t.trip_geom, b.geom, {k}, FALSE)
143-
"""
144-
result = self.sedonadb.execute_and_collect(query)
145-
return result.to_pandas().iloc[0]["result_count"]
146-
147-
result_count = benchmark(run_knn_query)
148-
expected_count = 50 * k # 50 trips * k neighbors each
149-
assert result_count == expected_count, (
150-
f"Expected {expected_count} results, got {result_count}"
151-
)
152-
153-
def test_knn_correctness(self):
154-
"""Verify KNN returns results in correct distance order"""
155-
156-
# Test with a known point and verify ordering
157-
query = """
158-
WITH test_point AS (
159-
SELECT ST_Point(0.0, 0.0) as query_geom
160-
)
161-
SELECT
162-
ST_Distance(test_point.query_geom, b.geom) as distance,
163-
b.building_id
164-
FROM test_point
165-
JOIN knn_buildings_small b ON ST_KNN(test_point.query_geom, b.geom, 5, FALSE)
166-
ORDER BY distance
167-
"""
193+
@pytest.mark.parametrize("engine", [SedonaDB, PostGIS, DuckDB])
194+
def test_knn_scalability_by_k(self, benchmark, k, engine):
195+
"""Test how KNN performance scales with increasing k values across SedonaDB, PostGIS, and DuckDB"""
168196

169-
result = self.sedonadb.execute_and_collect(query).to_pandas()
170-
171-
# Verify we got 5 results
172-
assert len(result) == 5, f"Expected 5 results, got {len(result)}"
173-
174-
# Verify distances are in ascending order
175-
distances = result["distance"].tolist()
176-
assert distances == sorted(distances), (
177-
f"Results not ordered by distance: {distances}"
178-
)
179-
180-
# Verify all distances are non-negative
181-
assert all(d >= 0 for d in distances), f"Found negative distances: {distances}"
182-
183-
def test_knn_tie_breaking(self):
184-
"""Test KNN behavior with tie-breaking when geometries have equal distances"""
185-
186-
# Create test data with known equal distances
187-
setup_query = """
188-
WITH test_points AS (
189-
SELECT 1 as id, ST_Point(1.0, 0.0) as geom
190-
UNION ALL
191-
SELECT 2 as id, ST_Point(-1.0, 0.0) as geom
192-
UNION ALL
193-
SELECT 3 as id, ST_Point(0.0, 1.0) as geom
194-
UNION ALL
195-
SELECT 4 as id, ST_Point(0.0, -1.0) as geom
196-
UNION ALL
197-
SELECT 5 as id, ST_Point(2.0, 0.0) as geom
198-
)
199-
SELECT * FROM test_points
200-
"""
201-
tie_test_tab = self.sedonadb.execute_and_collect(setup_query)
202-
self.sedonadb.create_table_arrow("knn_tie_test", tie_test_tab)
203-
204-
# Query for 2 nearest neighbors from origin - should get 2 of the 4 equidistant points
205-
query = """
206-
WITH query_point AS (
207-
SELECT ST_Point(0.0, 0.0) as geom
208-
)
209-
SELECT
210-
t.id,
211-
ST_Distance(query_point.geom, t.geom) as distance
212-
FROM query_point
213-
JOIN knn_tie_test t ON ST_KNN(query_point.geom, t.geom, 2, FALSE)
214-
ORDER BY distance, t.id
215-
"""
197+
# Get the appropriate engine instance
198+
eng = self._get_eng(engine)
216199

217-
result = self.sedonadb.execute_and_collect(query).to_pandas()
200+
def run_knn_query():
201+
if engine == SedonaDB:
202+
# SedonaDB syntax
203+
query = f"""
204+
WITH trip_sample AS (
205+
SELECT trip_id, geom as trip_geom
206+
FROM knn_trips_small
207+
LIMIT 50 -- Small sample for k scaling test
208+
)
209+
SELECT
210+
COUNT(*) as result_count
211+
FROM trip_sample t
212+
JOIN knn_buildings_small b ON ST_KNN(t.trip_geom, b.geom, {k}, FALSE)
213+
"""
214+
elif engine == PostGIS:
215+
# PostGIS syntax
216+
query = f"""
217+
WITH trip_sample AS (
218+
SELECT trip_id, geom as trip_geom
219+
FROM knn_trips_small
220+
LIMIT 50
221+
),
222+
ranked_neighbors AS (
223+
SELECT
224+
t.trip_id,
225+
ROW_NUMBER() OVER (PARTITION BY t.trip_id ORDER BY t.trip_geom <-> b.geom) as rn
226+
FROM trip_sample t
227+
CROSS JOIN knn_buildings_small b
228+
)
229+
SELECT COUNT(*) as result_count
230+
FROM ranked_neighbors
231+
WHERE rn <= {k}
232+
"""
233+
else: # DuckDB
234+
# DuckDB KNN simulation
235+
query = f"""
236+
WITH trip_sample AS (
237+
SELECT trip_id, geom as trip_geom
238+
FROM knn_trips_small
239+
LIMIT 50
240+
),
241+
ranked_neighbors AS (
242+
SELECT
243+
t.trip_id,
244+
ROW_NUMBER() OVER (PARTITION BY t.trip_id ORDER BY ST_Distance(t.trip_geom, b.geom) ASC) as rn
245+
FROM trip_sample t
246+
CROSS JOIN knn_buildings_small b
247+
)
248+
SELECT COUNT(*) as result_count
249+
FROM ranked_neighbors
250+
WHERE rn <= {k}
251+
"""
218252

219-
# Should get exactly 2 results
220-
assert len(result) == 2, f"Expected 2 results, got {len(result)}"
253+
result = eng.execute_and_collect(query)
254+
return result.to_pandas().iloc[0]["result_count"]
221255

222-
# Both should be at distance 1.0 (the 4 equidistant points)
223-
distances = result["distance"].tolist()
224-
assert all(abs(d - 1.0) < 1e-6 for d in distances), (
225-
f"Expected distances ~1.0, got {distances}"
226-
)
256+
# Run the benchmark
257+
benchmark(run_knn_query)

0 commit comments

Comments
 (0)