diff --git a/.github/workflows/direct_download_urls_test_for_sources.yml b/.github/workflows/direct_download_urls_test_for_sources.yml index 613f3c07e..3e06160aa 100644 --- a/.github/workflows/direct_download_urls_test_for_sources.yml +++ b/.github/workflows/direct_download_urls_test_for_sources.yml @@ -185,8 +185,19 @@ jobs: os.makedirs(os.path.dirname(zip_path), exist_ok=True) try: + # First attempt with SSL verification zip_file_req = requests.get(url, params=params, headers=headers, allow_redirects=True) zip_file_req.raise_for_status() + except requests.exceptions.SSLError as ssl_err: + print(f"{base}: SSL verification failed. Retrying without verification.") + try: + zip_file_req = requests.get(url, params=params, headers=headers, allow_redirects=True, verify=False) + zip_file_req.raise_for_status() + print(f"Warning: SSL verification was disabled for {url}. This is a security risk.") + except Exception as retry_e: + raise Exception( + f"{base}: Exception {retry_e} occurred when downloading the URL {url} with SSL verification disabled.\n" + ) except Exception as e: raise Exception( f"{base}: Exception {e} occurred when downloading the URL {url}.\n" diff --git a/tools/helpers.py b/tools/helpers.py index 2b732e709..c387076cd 100644 --- a/tools/helpers.py +++ b/tools/helpers.py @@ -1,13 +1,16 @@ +import datetime import json import os -import datetime +import uuid +from urllib.parse import urlparse + import gtfs_kit -import requests -from requests.exceptions import RequestException, HTTPError import pandas as pd +import requests from pandas.errors import ParserError +from requests.exceptions import RequestException, HTTPError from unidecode import unidecode -import uuid + from tools.constants import ( STOP_LAT, STOP_LON, @@ -15,9 +18,8 @@ MDB_SOURCE_FILENAME, ZIP, FALLBACK_HEADERS, - ) -from urllib.parse import urlparse + ######################### # I/O FUNCTIONS @@ -75,60 +77,78 @@ def to_csv(path, catalog, columns): catalog.to_csv(path, sep=",", index=False) +def get_fallback_headers(url, original_headers=None): + """Generate browser-like fallback headers for a given URL""" + return { + **FALLBACK_HEADERS, + **(original_headers or {}), + "Referer": f"{urlparse(url).scheme}://{urlparse(url).netloc}/", + "Host": urlparse(url).netloc + } + + def download_dataset(url, authentication_type, api_key_parameter_name=None, api_key_parameter_value=None): """ Downloads a dataset from the given URL using specified authentication mechanisms. The method performs a request to the URL with API key passed as either a query - parameter or a header, based on the chosen authentication type. If the download - fails with certain 403 errors, a fallback request with alternative headers is attempted. - It writes the dataset contents to a temporary file and returns the file path. - - :param url: The dataset's source URL. - :type url: str - :param authentication_type: The type of authentication mechanism to use (e.g., - 1 for parameter-based, 2 for header-based). - :type authentication_type: int - :param api_key_parameter_name: The name of the API key parameter/header. It is - optional if the dataset is publicly accessible or no authentication is - required. - :type api_key_parameter_name: str, optional - :param api_key_parameter_value: The value of the API key to authenticate the - request. It is optional if no authentication is required. - :type api_key_parameter_value: str, optional - :return: The file path where the downloaded dataset is temporarily stored. - :type return: str - :raises RequestException: If all attempts to download the dataset fail. + parameter or a header, based on the chosen authentication type. It implements + adaptive fallback strategies for HTTP 403 errors and SSL certificate errors. """ + file_path = os.path.join(os.getcwd(), str(uuid.uuid4())) - def make_request(url, params=None, headers=None): + params = {api_key_parameter_name: api_key_parameter_value} if authentication_type == 1 else None + headers = {api_key_parameter_name: api_key_parameter_value} if authentication_type == 2 else None + + tried_options = set() + current_headers = headers + verify_ssl = True + + for attempt in range(3): try: - response = requests.get(url, params=params, headers=headers, allow_redirects=True) + response = requests.get( + url, + params=params, + headers=current_headers, + allow_redirects=True, + verify=verify_ssl + ) response.raise_for_status() - return response.content - except HTTPError as e: - return None if e.response.status_code == 403 else RequestException( - f"HTTP error {e} when accessing {url}. A fallback attempt with alternative headers will be made.") - except RequestException as e: - raise RequestException(f"Request failed: {e}") - file_path = os.path.join(os.getcwd(), str(uuid.uuid4())) + if not verify_ssl: + import warnings + warnings.warn( + f"SSL verification was disabled when downloading {url}." + ) - params = {api_key_parameter_name: api_key_parameter_value} if authentication_type == 1 else None - headers = {api_key_parameter_name: api_key_parameter_value} if authentication_type == 2 else None + with open(file_path, "wb") as f: + f.write(response.content) + return file_path + + except requests.exceptions.HTTPError as e: + if e.response.status_code == 403 and "fallback_headers" not in tried_options: + current_headers = get_fallback_headers(url, headers) + tried_options.add("fallback_headers") + continue - zip_file = make_request(url, params, headers) or ( - make_request(url, params, {**FALLBACK_HEADERS, **(headers or {}), - "Referer": f"{urlparse(url).scheme}://{urlparse(url).netloc}/", - "Host": urlparse(url).netloc}) - ) + except requests.exceptions.SSLError: + if "disable_ssl" not in tried_options: + verify_ssl = False + tried_options.add("disable_ssl") + continue - if zip_file is None: - raise RequestException(f"FAILURE! Retry attempts failed for {url}.") + except requests.exceptions.RequestException: + pass - with open(file_path, "wb") as f: - f.write(zip_file) + if "fallback_headers" not in tried_options: + current_headers = get_fallback_headers(url, headers) + tried_options.add("fallback_headers") + elif "disable_ssl" not in tried_options: + verify_ssl = False + tried_options.add("disable_ssl") + else: + break - return file_path + raise requests.exceptions.RequestException(f"FAILURE! All download attempts failed for {url}.") ######################### @@ -137,14 +157,14 @@ def make_request(url, params=None, headers=None): def are_overlapping_boxes( - source_minimum_latitude, - source_maximum_latitude, - source_minimum_longitude, - source_maximum_longitude, - filter_minimum_latitude, - filter_maximum_latitude, - filter_minimum_longitude, - filter_maximum_longitude, + source_minimum_latitude, + source_maximum_latitude, + source_minimum_longitude, + source_maximum_longitude, + filter_minimum_latitude, + filter_maximum_latitude, + filter_minimum_longitude, + filter_maximum_longitude, ): """ Verifies if two boxes are overlapping in two dimensions. @@ -178,7 +198,7 @@ def are_overlapping_boxes( def are_overlapping_edges( - source_minimum, source_maximum, filter_minimum, filter_maximum + source_minimum, source_maximum, filter_minimum, filter_maximum ): """ Verifies if two edges are overlapping in one dimension. @@ -238,7 +258,7 @@ def is_readable(file_path, load_func): def create_latest_url( - country_code, subdivision_name, provider, data_type, mdb_source_id + country_code, subdivision_name, provider, data_type, mdb_source_id ): """ Creates the latest URL for an MDB Source. @@ -270,7 +290,7 @@ def create_latest_url( def create_filename( - country_code, subdivision_name, provider, data_type, mdb_source_id, extension + country_code, subdivision_name, provider, data_type, mdb_source_id, extension ): """ Creates the filename for an MDB Source. @@ -401,9 +421,9 @@ def extract_gtfs_bounding_box(file_path): stops_required_columns = {STOP_LAT, STOP_LON} stops_are_present = ( - stops is not None - and stops_required_columns.issubset(stops.columns) - and not (stops[STOP_LAT].dropna().empty or stops[STOP_LON].dropna().empty) + stops is not None + and stops_required_columns.issubset(stops.columns) + and not (stops[STOP_LAT].dropna().empty or stops[STOP_LON].dropna().empty) ) minimum_latitude = stops[STOP_LAT].dropna().min() if stops_are_present else None diff --git a/tools/tests/test_helpers.py b/tools/tests/test_helpers.py index cadf22386..46643092f 100644 --- a/tools/tests/test_helpers.py +++ b/tools/tests/test_helpers.py @@ -1,5 +1,11 @@ from unittest import TestCase, skip from unittest.mock import patch, Mock + +import pandas as pd +import requests +from freezegun import freeze_time +from requests.exceptions import HTTPError + from tools.helpers import ( are_overlapping_edges, are_overlapping_boxes, @@ -18,9 +24,6 @@ normalize, download_dataset, ) -import pandas as pd -from freezegun import freeze_time -from requests.exceptions import HTTPError class TestVerificationFunctions(TestCase): @@ -396,7 +399,7 @@ def test_to_csv(self): @patch("tools.helpers.os") @patch("tools.helpers.requests.get") def test_download_dataset_auth_type_empty( - self, mock_requests, mock_os, mock_uuid4, mock_open + self, mock_requests, mock_os, mock_uuid4, mock_open ): test_authentication_type = None test_api_key_parameter_name = None @@ -422,7 +425,7 @@ def test_download_dataset_auth_type_empty( @patch("tools.helpers.os") @patch("tools.helpers.requests.get") def test_download_dataset_auth_type_0( - self, mock_requests, mock_os, mock_uuid4, mock_open + self, mock_requests, mock_os, mock_uuid4, mock_open ): test_authentication_type = 0 test_api_key_parameter_name = None @@ -448,7 +451,7 @@ def test_download_dataset_auth_type_0( @patch("tools.helpers.os") @patch("tools.helpers.requests.get") def test_download_dataset_auth_type_1( - self, mock_requests, mock_os, mock_uuid4, mock_open + self, mock_requests, mock_os, mock_uuid4, mock_open ): test_authentication_type = 1 test_api_key_parameter_name = "some_name" @@ -477,7 +480,7 @@ def test_download_dataset_auth_type_1( @patch("tools.helpers.os") @patch("tools.helpers.requests.get") def test_download_dataset_auth_type_2( - self, mock_requests, mock_os, mock_uuid4, mock_open + self, mock_requests, mock_os, mock_uuid4, mock_open ): test_authentication_type = 2 test_api_key_parameter_name = "some_name" @@ -506,7 +509,7 @@ def test_download_dataset_auth_type_2( @patch("tools.helpers.os") @patch("tools.helpers.requests.get") def test_download_dataset_exception( - self, mock_requests, mock_os, mock_uuid4, mock_open + self, mock_requests, mock_os, mock_uuid4, mock_open ): test_authentication_type = None test_api_key_parameter_name = None @@ -527,7 +530,6 @@ def test_download_dataset_exception( @patch("tools.helpers.os") @patch("tools.helpers.requests.get") def test_download_dataset_403_fallback_success(self, mock_requests, mock_os, mock_uuid4, mock_open): - response_403 = Mock(status_code=403) response_403.raise_for_status.side_effect = HTTPError(response=response_403) @@ -537,7 +539,7 @@ def test_download_dataset_403_fallback_success(self, mock_requests, mock_os, moc mock_os.path.join.return_value = self.test_path under_test = download_dataset(url=self.test_url, authentication_type=0, api_key_parameter_name=None, - api_key_parameter_value=None, ) + api_key_parameter_value=None, ) self.assertEqual(under_test, self.test_path) self.assertEqual(mock_requests.call_count, 2) @@ -555,16 +557,53 @@ def test_download_dataset_403_fallback_failure(self, mock_requests, mock_os, moc response_403_1.raise_for_status.side_effect = HTTPError(response=response_403_1) response_403_2 = Mock(status_code=403) response_403_2.raise_for_status.side_effect = HTTPError(response=response_403_2) + response_403_3 = Mock(status_code=403) + response_403_3.raise_for_status.side_effect = HTTPError(response=response_403_3) - mock_requests.side_effect = [response_403_1, response_403_2] + mock_requests.side_effect = [response_403_1, response_403_2, response_403_3] mock_os.path.join.return_value = self.test_path self.assertRaises(RequestException, download_dataset, url=self.test_url, - authentication_type=test_authentication_type, api_key_parameter_name=test_api_key_parameter_name, - api_key_parameter_value=test_api_key_parameter_value, ) + authentication_type=test_authentication_type, + api_key_parameter_name=test_api_key_parameter_name, + api_key_parameter_value=test_api_key_parameter_value) - self.assertEqual(mock_requests.call_count, 2) + self.assertEqual(mock_requests.call_count, 3) mock_os.path.join.assert_called_once() mock_os.getcwd.assert_called_once() mock_uuid4.assert_called_once() mock_open.assert_not_called() + + @patch("tools.helpers.open") + @patch("tools.helpers.uuid.uuid4") + @patch("tools.helpers.os") + @patch("tools.helpers.requests.get") + def test_download_dataset_ssl_error_fallback(self, mock_requests, mock_os, mock_uuid4, mock_open): + test_authentication_type = 0 + test_api_key_parameter_name = None + test_api_key_parameter_value = None + + ssl_error = requests.exceptions.SSLError("SSL Certificate Verification Failed") + + response_200 = Mock(status_code=200, content=b"file_content") + + mock_requests.side_effect = [ssl_error, response_200] + mock_os.path.join.return_value = self.test_path + + under_test = download_dataset( + url=self.test_url, + authentication_type=test_authentication_type, + api_key_parameter_name=test_api_key_parameter_name, + api_key_parameter_value=test_api_key_parameter_value, + ) + + self.assertEqual(under_test, self.test_path) + self.assertEqual(mock_requests.call_count, 2) + + self.assertTrue(mock_requests.call_args_list[0].kwargs["verify"]) + self.assertFalse(mock_requests.call_args_list[1].kwargs["verify"]) + + mock_os.path.join.assert_called_once() + mock_os.getcwd.assert_called_once() + mock_uuid4.assert_called_once() + mock_open.assert_called_once()