diff --git a/tools/constants.py b/tools/constants.py index d1b7ffabd..9269150c8 100644 --- a/tools/constants.py +++ b/tools/constants.py @@ -99,3 +99,20 @@ LOAD_FUNC = "load_func" ZIP = "zip" JSON = "json" + +#browser header +FALLBACK_HEADERS = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/132.0.0.0 Safari/537.36", + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9," + "image/avif,image/webp,image/apng,*/*;q=0.8," + "application/signed-exchange;v=b3;q=0.7", + "Accept-Language": "en-US,en;q=0.9", + "Accept-Encoding": "gzip, deflate, br", + "Accept": "application/zip", + "Connection": "keep-alive", + "Sec-Fetch-Dest": "document", + "Sec-Fetch-Mode": "navigate", + "Sec-Fetch-Site": "same-origin", +} diff --git a/tools/helpers.py b/tools/helpers.py index 1a0791175..26a73b732 100644 --- a/tools/helpers.py +++ b/tools/helpers.py @@ -1,21 +1,27 @@ +import datetime import json import os -import datetime +import uuid +from urllib.parse import urlparse + import gtfs_kit -import requests -from requests.exceptions import RequestException import pandas as pd +import requests from pandas.errors import ParserError +from requests.exceptions import RequestException, HTTPError from unidecode import unidecode -import uuid + from tools.constants import ( STOP_LAT, STOP_LON, MDB_ARCHIVES_LATEST_URL_TEMPLATE, MDB_SOURCE_FILENAME, ZIP, + FALLBACK_HEADERS, + ) + ######################### # I/O FUNCTIONS ######################### @@ -55,7 +61,7 @@ def to_csv(path, catalog, columns): """ Save a catalog to a CSV file. - This function normalizes a catalog, optionally filters it by specified columns, + This function normalizes a catalog, optionally filters it by specified columns, and saves it to a CSV file at the given path. Args: @@ -72,52 +78,72 @@ def to_csv(path, catalog, columns): catalog.to_csv(path, sep=",", index=False) -def download_dataset( - url, authentication_type, api_key_parameter_name, api_key_parameter_value -): +def download_dataset(url, authentication_type, api_key_parameter_name=None, api_key_parameter_value=None): """ - Download a dataset from a given URL with optional authentication. - - This function downloads a dataset from the specified URL using optional - API key authentication and saves it to a file in the current working directory. - - Args: - url (str): The URL of the dataset to download. - authentication_type (int): The type of authentication to use. - 0: No authentication. - 1: API key as a query parameter. - 2: API key as a header. - api_key_parameter_name (str, optional): The name of the API key parameter. - api_key_parameter_value (str, optional): The value of the API key. - - Returns: - str: The path to the downloaded file. - - Raises: - RequestException: If an error occurs during the download process. + Downloads a dataset from the given URL using specified authentication mechanisms. + The method performs a request to the URL with API key passed as either a query + parameter or a header, based on the chosen authentication type. If the download + fails with certain 403 errors, a fallback request with alternative headers is attempted. + It writes the dataset contents to a temporary file and returns the file path. + + :param url: The dataset's source URL. + :type url: str + :param authentication_type: The type of authentication mechanism to use (e.g., + 1 for parameter-based, 2 for header-based). + :type authentication_type: int + :param api_key_parameter_name: The name of the API key parameter/header. It is + optional if the dataset is publicly accessible or no authentication is + required. + :type api_key_parameter_name: str, optional + :param api_key_parameter_value: The value of the API key to authenticate the + request. It is optional if no authentication is required. + :type api_key_parameter_value: str, optional + :return: The file path where the downloaded dataset is temporarily stored. + :type return: str + :raises RequestException: If all attempts to download the dataset fail. """ - file_name = str(uuid.uuid4()) - file_path = os.path.join(os.getcwd(), file_name) - - params = {} - headers = {} - if authentication_type == 1: - params[api_key_parameter_name] = api_key_parameter_value - elif authentication_type == 2: - headers[api_key_parameter_name] = api_key_parameter_value + def make_request(url, params=None, headers=None): + try: + response = requests.get(url, params=params, headers=headers, allow_redirects=True, verify=True) + response.raise_for_status() + return response.content + except requests.exceptions.SSLError as ssl_err: + ca_bundle_path = os.environ.get("SSL_CERT_PATH") + if ca_bundle_path and os.path.exists(ca_bundle_path): + print(f"SSL verification failed. Retrying with custom CA bundle: {ca_bundle_path}") + try: + response = requests.get(url, params=params, headers=headers, allow_redirects=True, + verify=ca_bundle_path) + response.raise_for_status() + return response.content + except Exception as e: + print(f"SSL retry failed: {e}") + return None + else: + print("Custom CA bundle not found. SSL verification failed.") + return None + except HTTPError as e: + return None if e.response.status_code == 403 else RequestException( + f"HTTP error {e} when accessing {url}. Fallback headers will be tried." + ) + except RequestException as e: + raise RequestException(f"Request failed: {e}") + + file_path = os.path.join(os.getcwd(), str(uuid.uuid4())) + + params = {api_key_parameter_name: api_key_parameter_value} if authentication_type == 1 else None + headers = {api_key_parameter_name: api_key_parameter_value} if authentication_type == 2 else None + + zip_file = make_request(url, params, headers) or make_request( + url, + params, + {**FALLBACK_HEADERS, **(headers or {}), "Referer": f"{urlparse(url).scheme}://{urlparse(url).netloc}/"} + ) - try: - zip_file_req = requests.get( - url, params=params, headers=headers, allow_redirects=True - ) - zip_file_req.raise_for_status() - except RequestException as e: - raise RequestException( - f"FAILURE! Exception {e} occurred when downloading URL {url}.\n" - ) + if zip_file is None: + raise RequestException(f"FAILURE! Retry attempts failed for {url}.") - zip_file = zip_file_req.content with open(file_path, "wb") as f: f.write(zip_file) @@ -130,14 +156,14 @@ def download_dataset( def are_overlapping_boxes( - source_minimum_latitude, - source_maximum_latitude, - source_minimum_longitude, - source_maximum_longitude, - filter_minimum_latitude, - filter_maximum_latitude, - filter_minimum_longitude, - filter_maximum_longitude, + source_minimum_latitude, + source_maximum_latitude, + source_minimum_longitude, + source_maximum_longitude, + filter_minimum_latitude, + filter_maximum_latitude, + filter_minimum_longitude, + filter_maximum_longitude, ): """ Verifies if two boxes are overlapping in two dimensions. @@ -171,7 +197,7 @@ def are_overlapping_boxes( def are_overlapping_edges( - source_minimum, source_maximum, filter_minimum, filter_maximum + source_minimum, source_maximum, filter_minimum, filter_maximum ): """ Verifies if two edges are overlapping in one dimension. @@ -186,7 +212,7 @@ def are_overlapping_edges( filter_maximum (float): The maximum coordinate of the filter edge. Returns: - bool: True if the two edges are overlapping, False otherwise. + bool: True if the two edges are overlapping, False otherwise. Returns False if one or more coordinates are None. """ return ( @@ -231,7 +257,7 @@ def is_readable(file_path, load_func): def create_latest_url( - country_code, subdivision_name, provider, data_type, mdb_source_id + country_code, subdivision_name, provider, data_type, mdb_source_id ): """ Creates the latest URL for an MDB Source. @@ -263,7 +289,7 @@ def create_latest_url( def create_filename( - country_code, subdivision_name, provider, data_type, mdb_source_id, extension + country_code, subdivision_name, provider, data_type, mdb_source_id, extension ): """ Creates the filename for an MDB Source. @@ -394,9 +420,9 @@ def extract_gtfs_bounding_box(file_path): stops_required_columns = {STOP_LAT, STOP_LON} stops_are_present = ( - stops is not None - and stops_required_columns.issubset(stops.columns) - and not (stops[STOP_LAT].dropna().empty or stops[STOP_LON].dropna().empty) + stops is not None + and stops_required_columns.issubset(stops.columns) + and not (stops[STOP_LAT].dropna().empty or stops[STOP_LON].dropna().empty) ) minimum_latitude = stops[STOP_LAT].dropna().min() if stops_are_present else None diff --git a/tools/tests/test_helpers.py b/tools/tests/test_helpers.py index 9d51010ab..cadf22386 100644 --- a/tools/tests/test_helpers.py +++ b/tools/tests/test_helpers.py @@ -20,6 +20,7 @@ ) import pandas as pd from freezegun import freeze_time +from requests.exceptions import HTTPError class TestVerificationFunctions(TestCase): @@ -408,8 +409,8 @@ def test_download_dataset_auth_type_empty( api_key_parameter_value=test_api_key_parameter_value, ) self.assertEqual(under_test, self.test_path) - self.assertEqual(mock_requests.call_args.kwargs["params"], {}) - self.assertEqual(mock_requests.call_args.kwargs["headers"], {}) + self.assertEqual(mock_requests.call_args.kwargs["params"], None) + self.assertEqual(mock_requests.call_args.kwargs["headers"], None) mock_requests.assert_called_once() mock_os.path.join.assert_called_once() mock_os.getcwd.assert_called_once() @@ -434,8 +435,8 @@ def test_download_dataset_auth_type_0( api_key_parameter_value=test_api_key_parameter_value, ) self.assertEqual(under_test, self.test_path) - self.assertEqual(mock_requests.call_args.kwargs["params"], {}) - self.assertEqual(mock_requests.call_args.kwargs["headers"], {}) + self.assertEqual(mock_requests.call_args.kwargs["params"], None) + self.assertEqual(mock_requests.call_args.kwargs["headers"], None) mock_requests.assert_called_once() mock_os.path.join.assert_called_once() mock_os.getcwd.assert_called_once() @@ -464,7 +465,7 @@ def test_download_dataset_auth_type_1( mock_requests.call_args.kwargs["params"], {test_api_key_parameter_name: test_api_key_parameter_value}, ) - self.assertEqual(mock_requests.call_args.kwargs["headers"], {}) + self.assertEqual(mock_requests.call_args.kwargs["headers"], None) mock_requests.assert_called_once() mock_os.path.join.assert_called_once() mock_os.getcwd.assert_called_once() @@ -489,7 +490,7 @@ def test_download_dataset_auth_type_2( api_key_parameter_value=test_api_key_parameter_value, ) self.assertEqual(under_test, self.test_path) - self.assertEqual(mock_requests.call_args.kwargs["params"], {}) + self.assertEqual(mock_requests.call_args.kwargs["params"], None) self.assertEqual( mock_requests.call_args.kwargs["headers"], {test_api_key_parameter_name: test_api_key_parameter_value}, @@ -520,3 +521,50 @@ def test_download_dataset_exception( api_key_parameter_name=test_api_key_parameter_name, api_key_parameter_value=test_api_key_parameter_value, ) + + @patch("tools.helpers.open") + @patch("tools.helpers.uuid.uuid4") + @patch("tools.helpers.os") + @patch("tools.helpers.requests.get") + def test_download_dataset_403_fallback_success(self, mock_requests, mock_os, mock_uuid4, mock_open): + + response_403 = Mock(status_code=403) + response_403.raise_for_status.side_effect = HTTPError(response=response_403) + + response_200 = Mock(status_code=200, content=b"file_content") + + mock_requests.side_effect = [response_403, response_200] + mock_os.path.join.return_value = self.test_path + + under_test = download_dataset(url=self.test_url, authentication_type=0, api_key_parameter_name=None, + api_key_parameter_value=None, ) + + self.assertEqual(under_test, self.test_path) + self.assertEqual(mock_requests.call_count, 2) + + @patch("tools.helpers.open") + @patch("tools.helpers.uuid.uuid4") + @patch("tools.helpers.os") + @patch("tools.helpers.requests.get") + def test_download_dataset_403_fallback_failure(self, mock_requests, mock_os, mock_uuid4, mock_open): + test_authentication_type = 0 + test_api_key_parameter_name = None + test_api_key_parameter_value = None + + response_403_1 = Mock(status_code=403) + response_403_1.raise_for_status.side_effect = HTTPError(response=response_403_1) + response_403_2 = Mock(status_code=403) + response_403_2.raise_for_status.side_effect = HTTPError(response=response_403_2) + + mock_requests.side_effect = [response_403_1, response_403_2] + + mock_os.path.join.return_value = self.test_path + self.assertRaises(RequestException, download_dataset, url=self.test_url, + authentication_type=test_authentication_type, api_key_parameter_name=test_api_key_parameter_name, + api_key_parameter_value=test_api_key_parameter_value, ) + + self.assertEqual(mock_requests.call_count, 2) + mock_os.path.join.assert_called_once() + mock_os.getcwd.assert_called_once() + mock_uuid4.assert_called_once() + mock_open.assert_not_called()