|
| 1 | +from geoalchemy2 import WKTElement |
1 | 2 | from sqlalchemy import select |
2 | 3 | from sqlalchemy.orm import joinedload |
3 | 4 | from sqlalchemy.orm.query import Query |
| 5 | +from sqlalchemy.orm.strategy_options import _AbstractLoad |
4 | 6 |
|
5 | | -from common.common import apply_bounding_filtering, get_joinedload_options |
6 | | -from common.error_handling import raise_internal_http_validation_error |
| 7 | +# from common.common import apply_bounding_filtering, get_joinedload_options |
| 8 | +from common.error_handling import ( |
| 9 | + raise_internal_http_validation_error, |
| 10 | + invalid_bounding_method, |
| 11 | + invalid_bounding_coordinates, |
| 12 | +) |
7 | 13 | from database.database import Database |
8 | 14 | from database_gen.sqlacodegen_models import ( |
9 | 15 | Gtfsdataset, |
|
12 | 18 | Validationreport, |
13 | 19 | Gtfsrealtimefeed, |
14 | 20 | Entitytype, |
| 21 | + Feed, |
15 | 22 | ) |
16 | 23 |
|
17 | 24 | from feeds.filters.gtfs_feed_filter import GtfsFeedFilter, LocationFilter |
@@ -133,3 +140,66 @@ def get_gtfs_rt_feeds_query( |
133 | 140 | .offset(offset) |
134 | 141 | ) |
135 | 142 | return feed_query |
| 143 | + |
| 144 | + |
| 145 | +def get_joinedload_options() -> [_AbstractLoad]: |
| 146 | + """Returns common joinedload options for feeds queries.""" |
| 147 | + return [joinedload(Feed.locations), joinedload(Feed.externalids), joinedload(Feed.redirectingids)] |
| 148 | + |
| 149 | + |
| 150 | +def apply_bounding_filtering( |
| 151 | + query: Query, |
| 152 | + bounding_latitudes: str, |
| 153 | + bounding_longitudes: str, |
| 154 | + bounding_filter_method: str, |
| 155 | +) -> Query: |
| 156 | + """Create a new query based on the bounding parameters.""" |
| 157 | + |
| 158 | + if not bounding_latitudes or not bounding_longitudes or not bounding_filter_method: |
| 159 | + return query |
| 160 | + |
| 161 | + if ( |
| 162 | + len(bounding_latitudes_tokens := bounding_latitudes.split(",")) != 2 |
| 163 | + or len(bounding_longitudes_tokens := bounding_longitudes.split(",")) != 2 |
| 164 | + ): |
| 165 | + raise_internal_http_validation_error( |
| 166 | + invalid_bounding_coordinates.format(bounding_latitudes, bounding_longitudes) |
| 167 | + ) |
| 168 | + min_latitude, max_latitude = bounding_latitudes_tokens |
| 169 | + min_longitude, max_longitude = bounding_longitudes_tokens |
| 170 | + try: |
| 171 | + min_latitude = float(min_latitude) |
| 172 | + max_latitude = float(max_latitude) |
| 173 | + min_longitude = float(min_longitude) |
| 174 | + max_longitude = float(max_longitude) |
| 175 | + except ValueError: |
| 176 | + raise_internal_http_validation_error( |
| 177 | + invalid_bounding_coordinates.format(bounding_latitudes, bounding_longitudes) |
| 178 | + ) |
| 179 | + |
| 180 | + points = [ |
| 181 | + (min_longitude, min_latitude), |
| 182 | + (min_longitude, max_latitude), |
| 183 | + (max_longitude, max_latitude), |
| 184 | + (max_longitude, min_latitude), |
| 185 | + (min_longitude, min_latitude), |
| 186 | + ] |
| 187 | + wkt_polygon = f"POLYGON(({', '.join(f'{lon} {lat}' for lon, lat in points)}))" |
| 188 | + bounding_box = WKTElement( |
| 189 | + wkt_polygon, |
| 190 | + srid=Gtfsdataset.bounding_box.type.srid, |
| 191 | + ) |
| 192 | + |
| 193 | + if bounding_filter_method == "partially_enclosed": |
| 194 | + return query.filter( |
| 195 | + or_( |
| 196 | + Gtfsdataset.bounding_box.ST_Overlaps(bounding_box), |
| 197 | + Gtfsdataset.bounding_box.ST_Contains(bounding_box), |
| 198 | + ) |
| 199 | + ) |
| 200 | + elif bounding_filter_method == "completely_enclosed": |
| 201 | + return query.filter(bounding_box.ST_Covers(Gtfsdataset.bounding_box)) |
| 202 | + elif bounding_filter_method == "disjoint": |
| 203 | + return query.filter(Gtfsdataset.bounding_box.ST_Disjoint(bounding_box)) |
| 204 | + else: |
| 205 | + raise raise_internal_http_validation_error(invalid_bounding_method.format(bounding_filter_method)) |
0 commit comments