@@ -37,66 +37,72 @@ def get_gtfs_feeds_query(
3737 dataset_latitudes : str | None = None ,
3838 dataset_longitudes : str | None = None ,
3939 bounding_filter_method : str | None = None ,
40- is_official : bool = False ,
41- include_wip : bool = False ,
40+ is_official : bool | None = None ,
41+ published_only : bool = True ,
4242 include_options_for_joinedload : bool = True ,
4343) -> Query [any ]:
4444 """Get the DB query to use to retrieve the GTFS feeds.."""
4545 gtfs_feed_filter = GtfsFeedFilter (
4646 stable_id = stable_id ,
4747 provider__ilike = provider ,
4848 producer_url__ilike = producer_url ,
49- location = LocationFilter (
50- country_code = country_code ,
51- subdivision_name__ilike = subdivision_name ,
52- municipality__ilike = municipality ,
53- ),
49+ location = None ,
5450 )
5551
56- subquery = gtfs_feed_filter .filter (select (Gtfsfeed .id ). join ( Location , Gtfsfeed . locations ) )
52+ subquery = gtfs_feed_filter .filter (select (Gtfsfeed .id ))
5753 subquery = apply_bounding_filtering (
5854 subquery , dataset_latitudes , dataset_longitudes , bounding_filter_method
5955 ).subquery ()
60-
6156 feed_query = (
6257 db_session .query (Gtfsfeed )
6358 .outerjoin (Gtfsfeed .gtfsdatasets )
6459 .filter (Gtfsfeed .id .in_ (subquery ))
65- .filter ((Gtfsdataset .latest ) | ( Gtfsdataset .id == None )) # noqa: E711
60+ .filter (or_ (Gtfsdataset .latest , Gtfsdataset .id == None )) # noqa: E711
6661 )
67- if not include_wip :
62+
63+ if country_code or subdivision_name or municipality :
64+ location_filter = LocationFilter (
65+ country_code = country_code ,
66+ subdivision_name__ilike = subdivision_name ,
67+ municipality__ilike = municipality ,
68+ )
69+ location_subquery = location_filter .filter (select (Location .id ))
70+ feed_query = feed_query .filter (Gtfsfeed .locations .any (Location .id .in_ (location_subquery )))
71+
72+ if published_only :
6873 feed_query = feed_query .filter (Gtfsfeed .operational_status == "published" )
6974
75+ feed_query = add_official_filter (feed_query , is_official )
76+
7077 if include_options_for_joinedload :
7178 feed_query = feed_query .options (
7279 contains_eager (Gtfsfeed .gtfsdatasets )
7380 .joinedload (Gtfsdataset .validation_reports )
7481 .joinedload (Validationreport .notices ),
7582 * get_joinedload_options (),
7683 ).order_by (Gtfsfeed .provider , Gtfsfeed .stable_id )
77- if is_official :
78- feed_query = feed_query .filter (Feed .official )
84+
7985 feed_query = feed_query .limit (limit ).offset (offset )
8086 return feed_query
8187
8288
8389def get_all_gtfs_feeds (
8490 db_session : Session ,
85- include_wip : bool = False ,
91+ published_only : bool = True ,
8692 batch_size : int = 250 ,
8793) -> Iterator [Gtfsfeed ]:
8894 """
8995 Fetch all GTFS feeds.
9096
9197 @param db_session: The database session.
92- @param include_wip: Whether to include or exclude WIP feeds.
98+ @param published_only: Include only the published feeds.
9399 @param batch_size: The number of feeds to fetch from the database at a time.
94100 A lower value means less memory but more queries.
95101
96102 @return: The GTFS feeds in an iterator.
97103 """
98104 feed_query = db_session .query (Gtfsfeed ).order_by (Gtfsfeed .stable_id ).yield_per (batch_size )
99- if not include_wip :
105+ if published_only :
100106 feed_query = feed_query .filter (Gtfsfeed .operational_status == "published" )
101107
102108 for batch in batched (feed_query , batch_size ):
@@ -126,7 +132,7 @@ def get_gtfs_rt_feeds_query(
126132 subdivision_name : str | None ,
127133 municipality : str | None ,
128134 is_official : bool | None ,
129- include_wip : bool = False ,
135+ published_only : bool = True ,
130136 db_session : Session = None ,
131137) -> Query :
132138 """Get some (or all) GTFS Realtime feeds from the Mobility Database."""
@@ -160,37 +166,49 @@ def get_gtfs_rt_feeds_query(
160166 ).subquery ()
161167 feed_query = db_session .query (Gtfsrealtimefeed ).filter (Gtfsrealtimefeed .id .in_ (subquery ))
162168
163- if not include_wip :
169+ if published_only :
164170 feed_query = feed_query .filter (Gtfsrealtimefeed .operational_status == "published" )
165171
166172 feed_query = feed_query .options (
167173 joinedload (Gtfsrealtimefeed .entitytypes ),
168174 joinedload (Gtfsrealtimefeed .gtfs_feeds ),
169175 * get_joinedload_options (),
170176 )
171- if is_official :
172- feed_query = feed_query . filter ( Feed . official )
177+ feed_query = add_official_filter ( feed_query , is_official )
178+
173179 feed_query = feed_query .limit (limit ).offset (offset )
174180 return feed_query
175181
176182
183+ def add_official_filter (query : Query , is_official : bool | None ) -> Query :
184+ """
185+ Add the is_official filter to the query if necessary
186+ """
187+ if is_official is not None :
188+ if is_official :
189+ query = query .filter (Feed .official .is_ (True ))
190+ else :
191+ query = query .filter (or_ (Feed .official .is_ (False ), Feed .official .is_ (None )))
192+ return query
193+
194+
177195def get_all_gtfs_rt_feeds (
178196 db_session : Session ,
179- include_wip : bool = False ,
197+ published_only : bool = True ,
180198 batch_size : int = 250 ,
181199) -> Iterator [Gtfsrealtimefeed ]:
182200 """
183201 Fetch all GTFS realtime feeds.
184202
185203 @param db_session: The database session.
186- @param include_wip: Whether to include or exclude WIP feeds.
204+ @param published_only: Include only the published feeds.
187205 @param batch_size: The number of feeds to fetch from the database at a time.
188206 A lower value means less memory but more queries.
189207
190208 @return: The GTFS realtime feeds in an iterator.
191209 """
192210 feed_query = db_session .query (Gtfsrealtimefeed .stable_id ).order_by (Gtfsrealtimefeed .stable_id ).yield_per (batch_size )
193- if not include_wip :
211+ if published_only :
194212 feed_query = feed_query .filter (Gtfsrealtimefeed .operational_status == "published" )
195213
196214 for batch in batched (feed_query , batch_size ):
0 commit comments