Skip to content

Commit 410091d

Browse files
authored
fix: sqlalchemy sync bugs + batch fill (#1134)
1 parent 63ee442 commit 410091d

File tree

4 files changed

+36
-18
lines changed

4 files changed

+36
-18
lines changed

functions-python/reverse_geolocation/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ This HTTP function initiates reverse geolocation for multiple feeds. It accepts
1616

1717
- **`country_codes`** (optional): A comma-separated list of country codes specifying which feeds should be processed.
1818
- If not provided, the function processes feeds from all available countries.
19+
- **`include_only_unprocessed`** (optional): A boolean flag indicating whether to include only feeds that have not been processed yet.
20+
- If set to `true`, only unprocessed feeds will be considered for reverse geolocation.
21+
- If set to `false`, all feeds will be processed, regardless of their processing status.
22+
- Default is `true`.
1923

2024
**Behavior:**
2125
The function publishes a message to the `reverse-geolocation` Pub/Sub topic for each non deprecated feed that matches the specified country codes.

functions-python/reverse_geolocation/src/reverse_geolocation_batch.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717

1818
@with_db_session
19-
def get_feeds_data(country_codes: List[str], db_session: Session) -> List[Dict]:
19+
def get_feeds_data(
20+
country_codes: List[str], include_only_unprocessed: bool, db_session: Session
21+
) -> List[Dict]:
2022
"""Get the feeds data for the given country codes. In case no country codes are provided, fetch feeds for all
2123
countries."""
2224
query = (
@@ -35,6 +37,10 @@ def get_feeds_data(country_codes: List[str], db_session: Session) -> List[Dict]:
3537
else:
3638
logging.warning("No country codes provided. Fetching feeds for all countries.")
3739

40+
if include_only_unprocessed:
41+
logging.info("Filtering for unprocessed feeds.")
42+
query = query.filter(~Gtfsfeed.feedlocationgrouppoints.any())
43+
3844
results = query.populate_existing().all()
3945
logging.info(f"Found {len(results)} feeds.")
4046

@@ -49,24 +55,28 @@ def get_feeds_data(country_codes: List[str], db_session: Session) -> List[Dict]:
4955
return data
5056

5157

52-
def parse_request_parameters(request: flask.Request) -> List[str]:
53-
"""Parse the request parameters to get the country codes."""
54-
country_codes = request.args.get("country_codes", "").split(",")
58+
def parse_request_parameters(request: flask.Request) -> Tuple[List[str], bool]:
59+
"""Parse the request parameters to get the country codes and whether to include only unprocessed feeds."""
60+
json_request = request.get_json()
61+
country_codes = json_request.get("country_codes", "").split(",")
5562
country_codes = [code.strip().upper() for code in country_codes if code]
5663

5764
# Validate country codes
5865
for country_code in country_codes:
5966
if not pycountry.countries.get(alpha_2=country_code):
6067
raise ValueError(f"Invalid country code: {country_code}")
61-
return country_codes
68+
include_only_unprocessed = (
69+
json_request.get("include_only_unprocessed", True) is True
70+
)
71+
return country_codes, include_only_unprocessed
6272

6373

6474
def reverse_geolocation_batch(request: flask.Request) -> Tuple[str, int]:
6575
"""Batch function to trigger reverse geolocation for feeds."""
6676
try:
6777
Logger.init_logger()
68-
country_codes = parse_request_parameters(request)
69-
feeds_data = get_feeds_data(country_codes)
78+
country_codes, include_only_unprocessed = parse_request_parameters(request)
79+
feeds_data = get_feeds_data(country_codes, include_only_unprocessed)
7080
logging.info(f"Valid feeds with latest dataset: {len(feeds_data)}")
7181

7282
pubsub_topic_name = os.getenv("PUBSUB_TOPIC_NAME", None)

functions-python/reverse_geolocation/src/reverse_geolocation_processor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def extract_location_aggregate(
215215
stop = Feedlocationgrouppoint(feed_id=feed_id, geometry=stop_point)
216216
stop.group = group
217217
db_session.add(stop)
218+
db_session.flush()
218219
logging.info(
219220
f"Point {stop_point} matched to {', '.join([g.name for g in geopolygons])}"
220221
)
@@ -359,11 +360,13 @@ def extract_location_aggregates(
359360
)
360361
for location_group in location_aggregates.values()
361362
]
362-
gtfs_feed.feedosmlocationgroups = osm_location_groups
363+
gtfs_feed.feedosmlocationgroups.clear()
364+
gtfs_feed.feedosmlocationgroups.extend(osm_location_groups)
365+
363366
for gtfs_rt_feed in gtfs_feed.gtfs_rt_feeds:
364367
logging.info(f"Updating GTFS-RT feed with stable ID {gtfs_rt_feed.stable_id}")
365368
gtfs_rt_feed.feedosmlocationgroups.clear()
366-
gtfs_rt_feed.feedosmlocationgroups = osm_location_groups
369+
gtfs_rt_feed.feedosmlocationgroups.extend(osm_location_groups)
367370

368371
feed_locations = []
369372
for location_aggregate in location_aggregates.values():

functions-python/reverse_geolocation/tests/test_reverse_geolocation_batch.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,16 @@ def test_get_feed_data(self, db_session):
5656
db_session.add(feed_2)
5757
db_session.commit()
5858

59-
results = get_feeds_data(["CA"])
59+
results = get_feeds_data(["CA"], True)
6060
self.assertEqual(len(results), 1)
6161
self.assertEqual(results[0]["stable_id"], "test_feed")
6262
self.assertEqual(results[0]["dataset_id"], "test_dataset_latest")
6363
self.assertEqual(results[0]["url"], "test_url")
6464

65-
results_2 = get_feeds_data([])
65+
results_2 = get_feeds_data([], False)
6666
self.assertEqual(len(results_2), 2)
6767

68-
results_3 = get_feeds_data(["US"])
68+
results_3 = get_feeds_data(["US"], True)
6969
self.assertEqual(len(results_3), 1)
7070
self.assertEqual(results_3[0]["stable_id"], "test_feed_2")
7171
self.assertEqual(results_3[0]["dataset_id"], "test_dataset_3")
@@ -75,16 +75,17 @@ def test_get_feed_data(self, db_session):
7575

7676
def test_parse_request_parameters(self):
7777
request = MagicMock()
78-
request.args.get = lambda value, default: {"country_codes": "Ca , uS"}.get(
79-
value, default
80-
)
78+
request.get_json.return_value.get = lambda value, default: {
79+
"country_codes": "Ca , uS"
80+
}.get(value, default)
8181
from reverse_geolocation_batch import parse_request_parameters
8282

83-
country_codes = parse_request_parameters(request)
84-
self.assertEqual(country_codes, ["CA", "US"])
83+
country_codes, include_only_unprocessed = parse_request_parameters(request)
84+
self.assertEqual(["CA", "US"], country_codes)
85+
self.assertTrue(include_only_unprocessed)
8586

8687
with pytest.raises(ValueError):
87-
request.args.get = lambda value, default: {
88+
request.get_json.return_value.get = lambda value, default: {
8889
"country_codes": "CA , US, XX"
8990
}.get(value, default)
9091
parse_request_parameters(request)

0 commit comments

Comments
 (0)