1- import sys
2- import types
1+ import os
32import unittest
43from typing import Any , Dict , List
54from unittest .mock import patch
1110from shared .database_gen .sqlacodegen_models import (
1211 Gtfsfeed ,
1312 Gtfsrealtimefeed ,
14- Entitytype ,
1513 Feedrelatedlink ,
1614)
1715
@@ -53,7 +51,6 @@ class _FakeSessionOK:
5351 DETAIL_TMPL = "https://api.gtfs-data.jp/v2/organizations/{org_id}/feeds/{feed_id}"
5452
5553 def get (self , url , timeout = 60 ):
56- # feeds index
5754 if url == self .FEEDS_URL :
5855 return _FakeResponse (
5956 {
@@ -85,7 +82,6 @@ def get(self, url, timeout=60):
8582 }
8683 )
8784
88- # details for feed1
8985 if url == self .DETAIL_TMPL .format (org_id = "org1" , feed_id = "feed1" ):
9086 return _FakeResponse (
9187 {
@@ -114,11 +110,6 @@ def get(self, url, timeout=60):
114110 }
115111 )
116112
117- # details for feed2 (won't be called, discontinued)
118- if url == self .DETAIL_TMPL .format (org_id = "org2" , feed_id = "feed2" ):
119- return _FakeResponse ({"body" : {}}, 404 )
120-
121- # details for feed3 (no gtfs_url -> skipped)
122113 if url == self .DETAIL_TMPL .format (org_id = "org3" , feed_id = "feed3" ):
123114 return _FakeResponse (
124115 {
@@ -140,6 +131,33 @@ def get(self, url, timeout=60):
140131 raise RuntimeError ("network down" )
141132
142133
class _FakeFuture:
    """Minimal stand-in for the future object returned by a fake publish call.

    The future behaves as if it had already completed: callbacks are invoked
    synchronously at registration time and ``result`` returns immediately.
    """

    def __init__(self):
        # Callbacks registered so far, in registration order.
        self._callbacks = []

    def add_done_callback(self, cb):
        """Register ``cb`` and call it right away with this future.

        An exception raised by ``cb`` is logged and not propagated, so a
        faulty callback never aborts the code that registered it.
        """
        import logging  # local import keeps this helper self-contained

        self._callbacks.append(cb)
        try:
            # Call immediately to simulate instant publish
            cb(self)
        except Exception:
            # Best-effort on purpose, but leave a trace instead of hiding it.
            logging.getLogger(__name__).exception(
                "done-callback %r raised in _FakeFuture", cb
            )

    def result(self, timeout=None):
        """Return ``None`` immediately; ``timeout`` is accepted and ignored."""
        return None  # instant ok
147+
148+
class _FakePublisher:
    """In-memory fake publisher that records every publish call.

    It is patched in below in place of ``pubsub_v1.PublisherClient`` so the
    import test can inspect what would have been sent.
    """

    def __init__(self):
        # One (topic_path, data_bytes) tuple per publish() call, in call order.
        self.published = []

    def topic_path(self, project_id, topic_name) -> str:
        """Return ``projects/<project_id>/topics/<topic_name>``.

        No whitespace may surround the placeholders: the result is compared
        verbatim against the expected topic path in the import test.
        """
        return f"projects/{project_id}/topics/{topic_name}"

    def publish(self, topic_path, data: bytes):
        """Record ``(topic_path, data)`` and return a new ``_FakeFuture``."""
        self.published.append((topic_path, data))
        return _FakeFuture()
159+
160+
143161class TestHelpers (unittest .TestCase ):
144162 def test_choose_gtfs_file (self ):
145163 files = [
@@ -169,43 +187,39 @@ def test_get_gtfs_file_url(self):
169187 self .assertIsNone (get_gtfs_file_url (detail , rid = "next_2" , kind = "gtfs_url" ))
170188
171189
172- class _RequestsModule (types .ModuleType ):
173- def __init__ (self , session_cls ):
174- super ().__init__ ("requests" )
175- self ._session_cls = session_cls
176-
177- class _SessionWrapper :
178- def __init__ (self , inner ):
179- self ._inner = inner
180-
181- def get (self , * a , ** k ):
182- return self ._inner .get (* a , ** k )
183-
184- def Session (self ):
185- return self ._SessionWrapper (self ._session_cls ())
190+ # ─────────────────────────────────────────────────────────────────────────────
191+ # Import tests
192+ # ─────────────────────────────────────────────────────────────────────────────
186193
187194
188195class TestImportJBDA (unittest .TestCase ):
189- def _patch_requests (self , session_cls ):
190- fake_requests = _RequestsModule (session_cls )
191- return patch .dict (sys .modules , {"requests" : fake_requests }, clear = False )
192-
193196 @with_db_session (db_url = default_db_url )
194197 def test_import_creates_gtfs_rt_and_related_links (self , db_session : Session ):
198+ fake_pub = _FakePublisher ()
199+
195200 with patch (
196201 "tasks.data_import.import_jbda_feeds.requests.Session" ,
197202 return_value = _FakeSessionOK (),
198- ), patch ("tasks.data_import.import_jbda_feeds.REQUEST_TIMEOUT_S" , 0.01 ):
203+ ), patch ("tasks.data_import.import_jbda_feeds.REQUEST_TIMEOUT_S" , 0.01 ), patch (
204+ "tasks.data_import.import_jbda_feeds.pubsub_v1.PublisherClient" ,
205+ return_value = fake_pub ,
206+ ), patch (
207+ "tasks.data_import.data_import_utils.PROJECT_ID" , "test-project"
208+ ), patch (
209+ "tasks.data_import.data_import_utils.DATASET_BATCH_TOPIC" , "dataset-batch"
210+ ), patch .dict (
211+ os .environ , {"COMMIT_BATCH_SIZE" : "1" }, clear = False
212+ ):
199213 result = import_jbda_handler ({"dry_run" : False })
200214
201- # Summary checks
215+ # Summary checks (unchanged)
202216 self .assertEqual (
203217 {
204218 "message" : "JBDA import executed successfully." ,
205219 "created_gtfs" : 1 ,
206220 "updated_gtfs" : 0 ,
207221 "created_rt" : 2 ,
208- "linked_refs" : 2 , # one per RT link established (tu + vp)
222+ "linked_refs" : 2 , # per RT link (tu + vp)
209223 "total_processed_items" : 1 ,
210224 "params" : {"dry_run" : False },
211225 },
@@ -221,15 +235,15 @@ def test_import_creates_gtfs_rt_and_related_links(self, db_session: Session):
221235 self .assertIsNotNone (sched )
222236 self .assertEqual (sched .feed_name , "Feed One" )
223237 self .assertEqual (sched .producer_url , "https://gtfs.example/one.zip" )
224- # Related links (only next_1 exists in detail)
238+
239+ # Related links (only next_1 exists)
225240 links : List [Feedrelatedlink ] = list (sched .feedrelatedlinks )
226241 codes = {link .code for link in links }
227242 self .assertIn ("next_1" , codes )
228- # URL for next_1 correct
229243 next1 = next (link for link in links if link .code == "next_1" )
230244 self .assertEqual (next1 .url , "https://gtfs.example/one-next.zip" )
231245
232- # DB checks for RT feeds
246+ # RT feeds + entity types + back-links
233247 tu = (
234248 db_session .query (Gtfsrealtimefeed )
235249 .filter (Gtfsrealtimefeed .stable_id == "jbda-feed1-tu" )
@@ -242,19 +256,28 @@ def test_import_creates_gtfs_rt_and_related_links(self, db_session: Session):
242256 )
243257 self .assertIsNotNone (tu )
244258 self .assertIsNotNone (vp )
245- # Each RT has single entity type and back-link to schedule
246259 self .assertEqual (len (tu .entitytypes ), 1 )
247260 self .assertEqual (len (vp .entitytypes ), 1 )
248- tu_et_name = db_session .query (Entitytype ).get (tu .entitytypes [0 ].name ).name
249- vp_et_name = db_session .query (Entitytype ).get (vp .entitytypes [0 ].name ).name
250- self .assertEqual (tu_et_name , "tu" )
251- self .assertEqual (vp_et_name , "vp" )
261+ self .assertEqual (tu .entitytypes [0 ].name , "tu" )
262+ self .assertEqual (vp .entitytypes [0 ].name , "vp" )
252263 self .assertEqual ([f .id for f in tu .gtfs_feeds ], [sched .id ])
253264 self .assertEqual ([f .id for f in vp .gtfs_feeds ], [sched .id ])
254- # RT producer_url set from RT endpoint
255265 self .assertEqual (tu .producer_url , "https://rt.example/one/tu.pb" )
256266 self .assertEqual (vp .producer_url , "https://rt.example/one/vp.pb" )
257267
268+ # Pub/Sub was called exactly once (only 1 new GTFS feed)
269+ self .assertEqual (len (fake_pub .published ), 1 )
270+ topic_path , data_bytes = fake_pub .published [0 ]
271+ self .assertEqual (topic_path , "projects/test-project/topics/dataset-batch" )
272+ # payload sanity
273+ import json
274+
275+ payload = json .loads (data_bytes .decode ("utf-8" ))
276+ self .assertEqual (payload ["feed_stable_id" ], "jbda-feed1" )
277+ self .assertEqual (payload ["producer_url" ], "https://gtfs.example/one.zip" )
278+ self .assertIsNone (payload ["dataset_id" ])
279+ self .assertIsNone (payload ["dataset_hash" ])
280+
258281 @with_db_session (db_url = default_db_url )
259282 def test_import_http_failure_graceful (self , db_session : Session ):
260283 with patch (
0 commit comments