Skip to content

Commit ee644de

Browse files
authored
fix: cartesian product with gtfsdatasets in query (#1261)
1 parent 963c7a2 commit ee644de

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

api/src/shared/common/db_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ def apply_bounding_filtering(
359359
wkt_polygon,
360360
srid=Gtfsdataset.bounding_box.type.srid,
361361
)
362+
query = query.join(Gtfsdataset, Gtfsdataset.feed_id == Gtfsfeed.id)
362363

363364
if bounding_filter_method == "partially_enclosed":
364365
return query.filter(

api/tests/integration/test_database.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
from typing import Final
33

44
import pytest
5+
from faker import Faker
56
from sqlalchemy.orm import Query
67

7-
from shared.common.db_utils import apply_bounding_filtering
8-
from shared.database.database import Database, generate_unique_id
9-
from shared.database_gen.sqlacodegen_models import Feature, Gtfsdataset
108
from feeds.impl.datasets_api_impl import DatasetsApiImpl
119
from feeds.impl.feeds_api_impl import FeedsApiImpl
12-
from faker import Faker
13-
10+
from shared.common.db_utils import apply_bounding_filtering
11+
from shared.database.database import Database, generate_unique_id
12+
from shared.database_gen.sqlacodegen_models import Feature, Gtfsfeed
1413
from tests.test_utils.database import TEST_GTFS_FEED_STABLE_IDS, TEST_DATASET_STABLE_IDS
1514

1615
VALIDATION_ERROR_NOTICES = 7
@@ -23,9 +22,7 @@
2322
VALIDATION_ERROR_COUNT_PER_NOTICE = 2
2423

2524

26-
BASE_QUERY = Query([Gtfsdataset, Gtfsdataset.bounding_box.ST_AsGeoJSON()]).filter(
27-
Gtfsdataset.stable_id == TEST_DATASET_STABLE_IDS[0]
28-
)
25+
BASE_QUERY = Query([Gtfsfeed])
2926
fake = Faker()
3027

3128

@@ -41,7 +38,8 @@ def test_bounding_box_dateset_exists(test_database):
4138
def assert_bounding_box_found(latitudes, longitudes, method, expected_found, test_database):
4239
with test_database.start_db_session() as session:
4340
query = apply_bounding_filtering(BASE_QUERY, latitudes, longitudes, method)
44-
result = test_database.select(session, query=query)
41+
assert query is not None, "apply_bounding_filtering returned None"
42+
result = session.execute(query).all()
4543
assert (len(result) > 0) is expected_found
4644

4745

0 commit comments

Comments
 (0)