Skip to content

Commit c181316

Browse files
committed
Add return value to main handler function and commit db_session changes
1 parent 0981bb7 commit c181316

File tree

5 files changed

+148
-49
lines changed

5 files changed

+148
-49
lines changed

functions-python/process_validation_report/tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def pytest_sessionfinish(session, exitstatus):
143143
returning the exit status to the system.
144144
"""
145145
# Cleaned at the beginning instead of the end so we can examine the DB after the test.
146-
# clean_testing_db()
146+
clean_testing_db()
147147

148148

149149
def pytest_unconfigure(config):

functions-python/process_validation_report/tests/test_validation_report.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def test_get_dataset(self, db_session):
8282
)
8383
try:
8484
db_session.add(feed)
85+
db_session.flush()
8586
db_session.add(dataset)
8687
db_session.flush()
8788
returned_dataset = get_dataset(dataset_stable_id, db_session)
@@ -123,6 +124,7 @@ def test_create_validation_report_entities(self, mock_get, db_session):
123124
)
124125
try:
125126
db_session.add(feed)
127+
db_session.flush()
126128
db_session.add(dataset)
127129
db_session.commit()
128130
create_validation_report_entities(feed_stable_id, dataset_stable_id, "1.0")
@@ -347,17 +349,18 @@ def test_create_validation_report_entities_missing_validator_version(
347349
],
348350
},
349351
)
350-
feed_stable_id = faker.word()
351-
dataset_stable_id = faker.word()
352+
feed_stable_id = faker.uuid4()
353+
dataset_stable_id = faker.uuid4()
352354

353355
# Create GTFS Feed
354-
feed = Gtfsfeed(id=faker.word(), data_type="gtfs", stable_id=feed_stable_id)
356+
feed = Gtfsfeed(id=faker.uuid4(), data_type="gtfs", stable_id=feed_stable_id)
355357
# Create a new dataset
356358
dataset = Gtfsdataset(
357-
id=faker.word(), feed_id=feed.id, stable_id=dataset_stable_id, latest=True
359+
id=faker.uuid4(), feed_id=feed.id, stable_id=dataset_stable_id, latest=True
358360
)
359361
try:
360362
db_session.add(feed)
363+
db_session.flush()
361364
db_session.add(dataset)
362365
db_session.commit()
363366
create_validation_report_entities(feed_stable_id, dataset_stable_id, "1.0")

functions-python/tasks_executor/src/tasks/geojson/update_geojson_files_precision.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,19 @@ def _upload_file(bucket, geojson):
7171

7272

7373
@track_metrics(metrics=("time", "memory", "cpu"))
74-
def _update_feed_info(feed, timestamp):
74+
def _update_feed_info(feed: Gtfsfeed, timestamp):
7575
feed.geolocation_file_created_date = timestamp
7676
# find the most recent dataset with bounding box and set the id
7777
if feed.gtfsdatasets and any(d.bounding_box for d in feed.gtfsdatasets):
7878
latest_with_bbox = max(
7979
(d for d in feed.gtfsdatasets if d.bounding_box),
80-
key=lambda d: d.downloaded_date or timestamp,
80+
key=lambda d: d.downloaded_at or timestamp,
81+
)
82+
feed.geolocation_file_dataset_id = latest_with_bbox.id
83+
else:
84+
logging.info(
85+
"No GTFS datasets available with bounding box for feed %s", feed.id
8186
)
82-
feed.geolocation_file_dataset_id = latest_with_bbox.bounding_box.id
8387

8488

8589
@track_metrics(metrics=("time", "memory", "cpu"))
@@ -171,7 +175,9 @@ def update_geojson_files_precision_handler(
171175

172176
geojson = process_geojson(geojson, precision)
173177
if not geojson:
174-
logging.info("No valid GeoJSON features found in %s", file.name)
178+
logging.warning("No valid GeoJSON features found in %s", file.name)
179+
errors.append(feed.stable_id)
180+
continue
175181

176182
# Optionally upload processed geojson
177183
if not dry_run:
@@ -182,7 +188,8 @@ def update_geojson_files_precision_handler(
182188
except Exception as e:
183189
logging.exception("Error processing feed %s: %s", feed.stable_id, e)
184190
errors.append(feed.stable_id)
185-
191+
if not dry_run and processed > 0:
192+
db_session.commit()
186193
summary = {
187194
"total_processed_files": processed,
188195
"errors": errors,
@@ -194,4 +201,4 @@ def update_geojson_files_precision_handler(
194201
},
195202
}
196203
logging.info("update_geojson_files_handler result: %s", summary)
197-
return
204+
return summary

functions-python/tasks_executor/tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def populate_database(db_session: Session | None = None):
3535
"""
3636
# Create 2 GTFS Feeds
3737
feeds = []
38+
# raise NotImplementedError("Implement the function to populate the database.")
3839
now = datetime.now(UTC)
3940
for i in range(2):
4041
feed = Gtfsfeed(
@@ -56,6 +57,7 @@ def populate_database(db_session: Session | None = None):
5657
feed=feed,
5758
stable_id=f"dataset_stable_{i:04d}",
5859
downloaded_at=now - timedelta(days=i),
60+
bounding_box="POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))",
5961
)
6062
db_session.add(dataset)
6163
datasets.append(dataset)

functions-python/tasks_executor/tests/tasks/geojson/test_update_geojson_files_precision.py

Lines changed: 125 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,32 @@
55
import unittest
66
from unittest.mock import patch
77

8+
from sqlalchemy.orm import Session
9+
10+
from shared.database.database import with_db_session
11+
from shared.helpers.src.shared.database_gen.sqlacodegen_models import Gtfsfeed
812
from tasks.geojson.update_geojson_files_precision import (
913
process_geojson,
1014
update_geojson_files_precision_handler,
1115
GEOLOCATION_FILENAME,
1216
)
17+
from test_shared.test_utils.database_utils import default_db_url
1318

1419

1520
class _FakeBlobContext:
16-
def __init__(self, bucket, name):
21+
def __init__(self, bucket, name, blob_exists=True):
1722
self.bucket = bucket
1823
self.name = name
24+
self.blob_exists = blob_exists
1925

2026
def __enter__(self):
2127
return self
2228

23-
def __exit__(self, exc_type, exc, tb):
24-
return False
29+
# def __exit__(self, exc_type, exc, tb):
30+
# return self.exists
2531

2632
def exists(self):
27-
return self.name in self.bucket.initial_blobs
33+
return self.blob_exists
2834

2935
def download_as_text(self):
3036
return self.bucket.initial_blobs[self.name]
@@ -60,15 +66,16 @@ def bucket(self, name):
6066

6167

6268
class FakeStorageModule:
63-
def __init__(self, bucket):
69+
def __init__(self, bucket, blob_exists=True):
6470
self._bucket = bucket
71+
self._blob_exists = blob_exists
6572

6673
def Client(self):
6774
return FakeClient(self._bucket)
6875

6976
# storage.Blob(...) used as a context manager in the handler
7077
def Blob(self, *, bucket, name):
71-
return _FakeBlobContext(bucket, name)
78+
return _FakeBlobContext(bucket, name, self._blob_exists)
7279

7380

7481
class TestUpdateGeojsonFilesPrecision(unittest.TestCase):
@@ -129,8 +136,8 @@ def test_process_geojson_single_feature_and_list_variants(self):
129136
[round(1.23456789, 3), round(2.3456789, 3)],
130137
)
131138

132-
@patch("tasks.geojson.update_geojson_files_precision.query_unprocessed_feeds")
133-
def test_handler_uploads_and_updates_feed_info(self, mock_query):
139+
@with_db_session(db_url=default_db_url)
140+
def test_handler_uploads_and_updates_feed_info(self, db_session: Session):
134141
geo = {
135142
"type": "FeatureCollection",
136143
"features": [
@@ -144,11 +151,12 @@ def test_handler_uploads_and_updates_feed_info(self, mock_query):
144151
}
145152
],
146153
}
147-
feed_stable_id = "feed_123"
154+
testing_feed = db_session.query(Gtfsfeed).limit(1).first()
155+
feed_stable_id = testing_feed.stable_id
148156
blob_name = f"{feed_stable_id}/{GEOLOCATION_FILENAME}"
149157

150158
fake_bucket = FakeBucket(initial_blobs={blob_name: json.dumps(geo)})
151-
fake_storage = FakeStorageModule(fake_bucket)
159+
fake_storage = FakeStorageModule(fake_bucket, blob_exists=True)
152160

153161
# create module objects for google and google.cloud and inject via sys.modules
154162
cloud_mod = types.ModuleType("google.cloud")
@@ -157,32 +165,6 @@ def test_handler_uploads_and_updates_feed_info(self, mock_query):
157165
google_mod = types.ModuleType("google")
158166
google_mod.cloud = cloud_mod
159167

160-
fake_feed = types.SimpleNamespace(
161-
stable_id=feed_stable_id,
162-
geolocation_file_created_date=None,
163-
gtfsdatasets=[
164-
types.SimpleNamespace(
165-
bounding_box=types.SimpleNamespace(id="bbid"), downloaded_date=None
166-
)
167-
],
168-
)
169-
170-
mock_query.return_value = [fake_feed]
171-
172-
# fake db session
173-
class FakeExecResult:
174-
def scalar(self):
175-
return "NOW_TS"
176-
177-
class FakeDBSession:
178-
def execute(self, q):
179-
return FakeExecResult()
180-
181-
def commit(self):
182-
return None
183-
184-
fake_db = FakeDBSession()
185-
186168
payload = {
187169
"bucket_name": "any-bucket",
188170
"dry_run": False,
@@ -193,7 +175,9 @@ def commit(self):
193175
# Inject modules into sys.modules for the duration of the handler call
194176
with patch.dict(sys.modules, {"google.cloud": cloud_mod, "google": google_mod}):
195177
# call wrapped handler to provide fake db_session
196-
update_geojson_files_precision_handler(payload, db_session=fake_db)
178+
result = update_geojson_files_precision_handler(
179+
payload, db_session=db_session
180+
)
197181

198182
# verify upload happened
199183
self.assertIn("geolocation.geojson", fake_bucket.uploaded)
@@ -202,8 +186,111 @@ def commit(self):
202186
coords = uploaded_geo.get("features")[0]["geometry"]["coordinates"]
203187
self.assertEqual(coords, [round(100.1234567, 5), round(0.9876543, 5)])
204188

189+
self.assertEqual(
190+
{
191+
"total_processed_files": 1,
192+
"errors": [],
193+
"not_found_file": 0,
194+
"params": {
195+
"dry_run": False,
196+
"precision": 5,
197+
"limit": 1,
198+
},
199+
},
200+
result,
201+
)
205202
# feed updated
206-
self.assertEqual(fake_feed.geolocation_file_created_date, "NOW_TS")
203+
reloaded_testing_feed = (
204+
db_session.query(Gtfsfeed)
205+
.filter(Gtfsfeed.id.__eq__(testing_feed.id))
206+
.limit(1)
207+
.first()
208+
)
209+
self.assertIsNotNone(reloaded_testing_feed.geolocation_file_dataset_id)
210+
self.assertIsNotNone(reloaded_testing_feed.geolocation_file_created_date)
211+
212+
@with_db_session(db_url=default_db_url)
213+
def test_handler_file_dont_exists(self, db_session: Session):
214+
fake_bucket = FakeBucket(initial_blobs={})
215+
fake_storage = FakeStorageModule(fake_bucket, blob_exists=False)
216+
217+
# create module objects for google and google.cloud and inject via sys.modules
218+
cloud_mod = types.ModuleType("google.cloud")
219+
# 'from google.cloud import storage' in handler will bind 'storage' to this attribute
220+
cloud_mod.storage = fake_storage
221+
google_mod = types.ModuleType("google")
222+
google_mod.cloud = cloud_mod
223+
224+
payload = {
225+
"bucket_name": "any-bucket",
226+
"dry_run": False,
227+
"precision": 5,
228+
"limit": 1,
229+
}
230+
231+
# Inject modules into sys.modules for the duration of the handler call
232+
with patch.dict(sys.modules, {"google.cloud": cloud_mod, "google": google_mod}):
233+
# call wrapped handler to provide fake db_session
234+
result = update_geojson_files_precision_handler(
235+
payload, db_session=db_session
236+
)
237+
self.assertEqual(
238+
{
239+
"total_processed_files": 0,
240+
"errors": [],
241+
"not_found_file": 1,
242+
"params": {
243+
"dry_run": False,
244+
"precision": 5,
245+
"limit": 1,
246+
},
247+
},
248+
result,
249+
)
250+
251+
@with_db_session(db_url=default_db_url)
252+
def test_handler_file_not_valid_file(self, db_session: Session):
253+
geo = "{}"
254+
testing_feed = db_session.query(Gtfsfeed).limit(1).first()
255+
feed_stable_id = testing_feed.stable_id
256+
blob_name = f"{feed_stable_id}/{GEOLOCATION_FILENAME}"
257+
258+
fake_bucket = FakeBucket(initial_blobs={blob_name: geo})
259+
fake_storage = FakeStorageModule(fake_bucket, blob_exists=True)
260+
261+
# create module objects for google and google.cloud and inject via sys.modules
262+
cloud_mod = types.ModuleType("google.cloud")
263+
# 'from google.cloud import storage' in handler will bind 'storage' to this attribute
264+
cloud_mod.storage = fake_storage
265+
google_mod = types.ModuleType("google")
266+
google_mod.cloud = cloud_mod
267+
268+
payload = {
269+
"bucket_name": "any-bucket",
270+
"dry_run": False,
271+
"precision": 5,
272+
"limit": 1,
273+
}
274+
testing_feed = db_session.query(Gtfsfeed).limit(1).first()
275+
# Inject modules into sys.modules for the duration of the handler call
276+
with patch.dict(sys.modules, {"google.cloud": cloud_mod, "google": google_mod}):
277+
# call wrapped handler to provide fake db_session
278+
result = update_geojson_files_precision_handler(
279+
payload, db_session=db_session
280+
)
281+
self.assertEqual(
282+
{
283+
"total_processed_files": 0,
284+
"errors": [testing_feed.stable_id],
285+
"not_found_file": 0,
286+
"params": {
287+
"dry_run": False,
288+
"precision": 5,
289+
"limit": 1,
290+
},
291+
},
292+
result,
293+
)
207294

208295

209296
if __name__ == "__main__":

0 commit comments

Comments (0)