
Commit 44d2ded

fix: downstream tasks tested
1 parent bb321ee commit 44d2ded

File tree

functions-python/tasks_executor/src/main.py
functions-python/tasks_executor/src/tasks/data_import/import_jbda_feeds.py
functions-python/tasks_executor/tests/tasks/data_import/test_jbda_import.py

3 files changed: +74 -48 lines changed

functions-python/tasks_executor/src/main.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@
         "description": "Rebuilds missing visualization files for GTFS datasets.",
         "handler": rebuild_missing_visualization_files_handler,
     },
-    "data_import": {
+    "jbda_import": {
         "description": "Imports JBDA data into the system.",
         "handler": import_jbda_handler,
     },
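The renamed key is the public task identifier: the executor dispatches on the name callers send, so requests must now use "jbda_import" rather than "data_import". A minimal sketch of that name-based dispatch, assuming a module-level TASKS dict shaped like the entries above (resolve_handler and the stub handler are hypothetical, not the repository's actual code):

def import_jbda_handler(params: dict) -> dict:
    # Stand-in for the real handler imported by main.py.
    return {"message": "ok", "params": params}

# Hypothetical registry shaped like the entries in the diff above.
TASKS = {
    "jbda_import": {
        "description": "Imports JBDA data into the system.",
        "handler": import_jbda_handler,
    },
}

def resolve_handler(task_name: str):
    # Dispatch on the task name; the old "data_import" now fails fast here.
    entry = TASKS.get(task_name)
    if entry is None:
        raise ValueError(f"Unknown task: {task_name!r}")
    return entry["handler"]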

functions-python/tasks_executor/src/tasks/data_import/import_jbda_feeds.py

Lines changed: 10 additions & 7 deletions
@@ -369,8 +369,6 @@ def _import_jbda(db_session: Session, dry_run: bool = True) -> dict:
         dict: Result summary with message and counters.
     """
     logger.info("Starting JBDA import dry_run=%s", dry_run)
-    execution_id = uuid.uuid4()
-    publisher = pubsub_v1.PublisherClient()
     session_http = requests.Session()
     try:
         res = session_http.get(FEEDS_URL, timeout=REQUEST_TIMEOUT_S)

@@ -391,16 +389,17 @@ def _import_jbda(db_session: Session, dry_run: bool = True) -> dict:
         payload = res.json() or {}
         feeds_list: List[dict] = payload.get("body") or []
         logger.info(
-            "Commit batch size (env GIT_COMMIT_BATCH_SIZE)=%s",
-            os.getenv("GIT_COMMIT_BATCH_SIZE", "20"),
+            "Commit batch size (env COMMIT_BATCH_SIZE)=%s",
+            os.getenv("COMMIT_BATCH_SIZE", "20"),
         )

         created_gtfs = 0
         updated_gtfs = 0
         created_rt = 0
         linked_refs = 0
         total_processed = 0
-        commit_batch_size = int(os.getenv("GIT_COMMIT_BATCH_SIZE", 20))
+        commit_batch_size = int(os.getenv("COMMIT_BATCH_SIZE", 20))
+        feeds_to_publish: List[Feed] = []

         for idx, item in enumerate(feeds_list, start=1):
             try:
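Both the log line and the parse now read COMMIT_BATCH_SIZE (previously GIT_COMMIT_BATCH_SIZE), and the new feeds_to_publish list collects newly created feeds instead of publishing them inline; the next two hunks show where it is consumed. For reference, a minimal sketch of the batched-commit pattern this variable controls, assuming a SQLAlchemy-style db_session and with handle_item injected as a placeholder for the per-feed work:

import os

def run_batched(db_session, items, handle_item, dry_run=True):
    # Commit every COMMIT_BATCH_SIZE items (default 20) so a long import
    # does not accumulate one giant transaction.
    commit_batch_size = int(os.getenv("COMMIT_BATCH_SIZE", 20))
    total_processed = 0
    for item in items:
        handle_item(db_session, item)
        total_processed += 1
        if not dry_run and total_processed % commit_batch_size == 0:
            db_session.commit()
    if not dry_run:
        db_session.commit()  # final commit for the remainder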
@@ -516,8 +515,8 @@ def _import_jbda(db_session: Session, dry_run: bool = True) -> dict:
                     linked_refs += 1

                 total_processed += 1
-                if is_new_gtfs:
-                    trigger_dataset_download(gtfs_feed, execution_id, publisher)
+                if is_new_gtfs and not dry_run:
+                    feeds_to_publish.append(gtfs_feed)

                 if not dry_run and (total_processed % commit_batch_size == 0):
                     logger.info("Committing batch at total_processed=%d", total_processed)

@@ -540,6 +539,10 @@ def _import_jbda(db_session: Session, dry_run: bool = True) -> dict:
             "Final commit after processing all items (count=%d)", total_processed
         )
         db_session.commit()
+        execution_id = str(uuid.uuid4())
+        publisher = pubsub_v1.PublisherClient()
+        for feed in feeds_to_publish:
+            trigger_dataset_download(feed, execution_id, publisher)
     except IntegrityError:
         db_session.rollback()
         logger.exception("Final commit failed with IntegrityError; rolled back")
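Taken together, the last two hunks are an ordering fix: publishing moves out of the per-item loop and happens only after the final commit succeeds, so the IntegrityError rollback path can no longer leave Pub/Sub messages in flight for feeds that never reached the database, and dry runs publish nothing at all. A minimal sketch of this commit-then-publish pattern, with process_item and publish_download injected as placeholders for the real helpers:

def import_and_publish(db_session, items, process_item, publish_download, dry_run=True):
    feeds_to_publish = []
    for item in items:
        feed, is_new = process_item(db_session, item)
        if is_new and not dry_run:
            feeds_to_publish.append(feed)  # defer the Pub/Sub side effect
    if dry_run:
        db_session.rollback()  # assumed dry-run behavior: discard changes
        return
    db_session.commit()  # persist everything first
    for feed in feeds_to_publish:
        publish_download(feed)  # publish only after a successful commit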

functions-python/tasks_executor/tests/tasks/data_import/test_jbda_import.py

Lines changed: 63 additions & 40 deletions
@@ -1,5 +1,4 @@
-import sys
-import types
+import os
 import unittest
 from typing import Any, Dict, List
 from unittest.mock import patch

@@ -11,7 +10,6 @@
 from shared.database_gen.sqlacodegen_models import (
     Gtfsfeed,
     Gtfsrealtimefeed,
-    Entitytype,
     Feedrelatedlink,
 )

@@ -53,7 +51,6 @@ class _FakeSessionOK:
     DETAIL_TMPL = "https://api.gtfs-data.jp/v2/organizations/{org_id}/feeds/{feed_id}"

     def get(self, url, timeout=60):
-        # feeds index
         if url == self.FEEDS_URL:
             return _FakeResponse(
                 {

@@ -85,7 +82,6 @@ def get(self, url, timeout=60):
             }
         )

-        # details for feed1
         if url == self.DETAIL_TMPL.format(org_id="org1", feed_id="feed1"):
             return _FakeResponse(
                 {

@@ -114,11 +110,6 @@ def get(self, url, timeout=60):
             }
         )

-        # details for feed2 (won't be called, discontinued)
-        if url == self.DETAIL_TMPL.format(org_id="org2", feed_id="feed2"):
-            return _FakeResponse({"body": {}}, 404)
-
-        # details for feed3 (no gtfs_url -> skipped)
         if url == self.DETAIL_TMPL.format(org_id="org3", feed_id="feed3"):
             return _FakeResponse(
                 {

@@ -140,6 +131,33 @@ def get(self, url, timeout=60):
         raise RuntimeError("network down")


+class _FakeFuture:
+    def __init__(self):
+        self._callbacks = []
+
+    def add_done_callback(self, cb):
+        # Call immediately to simulate instant publish
+        try:
+            cb(self)
+        except Exception:
+            pass
+
+    def result(self, timeout=None):
+        return None  # instant ok
+
+
+class _FakePublisher:
+    def __init__(self):
+        self.published = []  # list of tuples (topic_path, data_bytes)
+
+    def topic_path(self, project_id, topic_name):
+        return f"projects/{project_id}/topics/{topic_name}"
+
+    def publish(self, topic_path, data: bytes):
+        self.published.append((topic_path, data))
+        return _FakeFuture()
+
+
 class TestHelpers(unittest.TestCase):
     def test_choose_gtfs_file(self):
         files = [
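The two doubles added above fake just enough of the google-cloud-pubsub PublisherClient surface (topic_path, and publish returning a future-like object) for the import path to run with only the constructor patched. A standalone illustration of how a test can assert on what was captured, assuming the class definitions above:

import json

pub = _FakePublisher()
topic = pub.topic_path("test-project", "dataset-batch")
future = pub.publish(topic, json.dumps({"feed_stable_id": "jbda-feed1"}).encode("utf-8"))
assert future.result() is None  # the fake future resolves immediately

topic_path, data = pub.published[0]
assert topic_path == "projects/test-project/topics/dataset-batch"
assert json.loads(data)["feed_stable_id"] == "jbda-feed1"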
@@ -169,43 +187,39 @@ def test_get_gtfs_file_url(self):
         self.assertIsNone(get_gtfs_file_url(detail, rid="next_2", kind="gtfs_url"))


-class _RequestsModule(types.ModuleType):
-    def __init__(self, session_cls):
-        super().__init__("requests")
-        self._session_cls = session_cls
-
-    class _SessionWrapper:
-        def __init__(self, inner):
-            self._inner = inner
-
-        def get(self, *a, **k):
-            return self._inner.get(*a, **k)
-
-    def Session(self):
-        return self._SessionWrapper(self._session_cls())
+# ─────────────────────────────────────────────────────────────────────────────
+# Import tests
+# ─────────────────────────────────────────────────────────────────────────────


 class TestImportJBDA(unittest.TestCase):
-    def _patch_requests(self, session_cls):
-        fake_requests = _RequestsModule(session_cls)
-        return patch.dict(sys.modules, {"requests": fake_requests}, clear=False)
-
     @with_db_session(db_url=default_db_url)
     def test_import_creates_gtfs_rt_and_related_links(self, db_session: Session):
+        fake_pub = _FakePublisher()
+
         with patch(
             "tasks.data_import.import_jbda_feeds.requests.Session",
             return_value=_FakeSessionOK(),
-        ), patch("tasks.data_import.import_jbda_feeds.REQUEST_TIMEOUT_S", 0.01):
+        ), patch("tasks.data_import.import_jbda_feeds.REQUEST_TIMEOUT_S", 0.01), patch(
+            "tasks.data_import.import_jbda_feeds.pubsub_v1.PublisherClient",
+            return_value=fake_pub,
+        ), patch(
+            "tasks.data_import.data_import_utils.PROJECT_ID", "test-project"
+        ), patch(
+            "tasks.data_import.data_import_utils.DATASET_BATCH_TOPIC", "dataset-batch"
+        ), patch.dict(
+            os.environ, {"COMMIT_BATCH_SIZE": "1"}, clear=False
+        ):
             result = import_jbda_handler({"dry_run": False})

-        # Summary checks
+        # Summary checks (unchanged)
         self.assertEqual(
             {
                 "message": "JBDA import executed successfully.",
                 "created_gtfs": 1,
                 "updated_gtfs": 0,
                 "created_rt": 2,
-                "linked_refs": 2,  # one per RT link established (tu + vp)
+                "linked_refs": 2,  # per RT link (tu + vp)
                 "total_processed_items": 1,
                 "params": {"dry_run": False},
             },

@@ -221,15 +235,15 @@ def test_import_creates_gtfs_rt_and_related_links(self, db_session: Session):
         self.assertIsNotNone(sched)
         self.assertEqual(sched.feed_name, "Feed One")
         self.assertEqual(sched.producer_url, "https://gtfs.example/one.zip")
-        # Related links (only next_1 exists in detail)
+
+        # Related links (only next_1 exists)
         links: List[Feedrelatedlink] = list(sched.feedrelatedlinks)
         codes = {link.code for link in links}
         self.assertIn("next_1", codes)
-        # URL for next_1 correct
         next1 = next(link for link in links if link.code == "next_1")
         self.assertEqual(next1.url, "https://gtfs.example/one-next.zip")

-        # DB checks for RT feeds
+        # RT feeds + entity types + back-links
         tu = (
             db_session.query(Gtfsrealtimefeed)
             .filter(Gtfsrealtimefeed.stable_id == "jbda-feed1-tu")

@@ -242,19 +256,28 @@ def test_import_creates_gtfs_rt_and_related_links(self, db_session: Session):
         )
         self.assertIsNotNone(tu)
         self.assertIsNotNone(vp)
-        # Each RT has single entity type and back-link to schedule
         self.assertEqual(len(tu.entitytypes), 1)
         self.assertEqual(len(vp.entitytypes), 1)
-        tu_et_name = db_session.query(Entitytype).get(tu.entitytypes[0].name).name
-        vp_et_name = db_session.query(Entitytype).get(vp.entitytypes[0].name).name
-        self.assertEqual(tu_et_name, "tu")
-        self.assertEqual(vp_et_name, "vp")
+        self.assertEqual(tu.entitytypes[0].name, "tu")
+        self.assertEqual(vp.entitytypes[0].name, "vp")
         self.assertEqual([f.id for f in tu.gtfs_feeds], [sched.id])
         self.assertEqual([f.id for f in vp.gtfs_feeds], [sched.id])
-        # RT producer_url set from RT endpoint
         self.assertEqual(tu.producer_url, "https://rt.example/one/tu.pb")
         self.assertEqual(vp.producer_url, "https://rt.example/one/vp.pb")

+        # Pub/Sub was called exactly once (only 1 new GTFS feed)
+        self.assertEqual(len(fake_pub.published), 1)
+        topic_path, data_bytes = fake_pub.published[0]
+        self.assertEqual(topic_path, "projects/test-project/topics/dataset-batch")
+        # payload sanity
+        import json
+
+        payload = json.loads(data_bytes.decode("utf-8"))
+        self.assertEqual(payload["feed_stable_id"], "jbda-feed1")
+        self.assertEqual(payload["producer_url"], "https://gtfs.example/one.zip")
+        self.assertIsNone(payload["dataset_id"])
+        self.assertIsNone(payload["dataset_hash"])
+
     @with_db_session(db_url=default_db_url)
     def test_import_http_failure_graceful(self, db_session: Session):
         with patch(