Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/direct_download_urls_test_for_sources.yml
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,19 @@ jobs:
os.makedirs(os.path.dirname(zip_path), exist_ok=True)

try:
# First attempt with SSL verification
zip_file_req = requests.get(url, params=params, headers=headers, allow_redirects=True)
zip_file_req.raise_for_status()
except requests.exceptions.SSLError as ssl_err:
print(f"{base}: SSL verification failed. Retrying without verification.")
try:
zip_file_req = requests.get(url, params=params, headers=headers, allow_redirects=True, verify=False)
zip_file_req.raise_for_status()
print(f"Warning: SSL verification was disabled for {url}. This is a security risk.")
except Exception as retry_e:
raise Exception(
f"{base}: Exception {retry_e} occurred when downloading the URL {url} with SSL verification disabled.\n"
)
except Exception as e:
raise Exception(
f"{base}: Exception {e} occurred when downloading the URL {url}.\n"
Expand Down
140 changes: 80 additions & 60 deletions tools/helpers.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import datetime
import json
import os
import datetime
import uuid
from urllib.parse import urlparse

import gtfs_kit
import requests
from requests.exceptions import RequestException, HTTPError
import pandas as pd
import requests
from pandas.errors import ParserError
from requests.exceptions import RequestException, HTTPError
from unidecode import unidecode
import uuid

from tools.constants import (
STOP_LAT,
STOP_LON,
MDB_ARCHIVES_LATEST_URL_TEMPLATE,
MDB_SOURCE_FILENAME,
ZIP,
FALLBACK_HEADERS,

)
from urllib.parse import urlparse


#########################
# I/O FUNCTIONS
Expand Down Expand Up @@ -75,60 +77,78 @@ def to_csv(path, catalog, columns):
catalog.to_csv(path, sep=",", index=False)


def get_fallback_headers(url, original_headers=None):
"""Generate browser-like fallback headers for a given URL"""
return {
**FALLBACK_HEADERS,
**(original_headers or {}),
"Referer": f"{urlparse(url).scheme}://{urlparse(url).netloc}/",
"Host": urlparse(url).netloc
}


def download_dataset(url, authentication_type, api_key_parameter_name=None, api_key_parameter_value=None):
"""
Downloads a dataset from the given URL using specified authentication mechanisms.
The method performs a request to the URL with API key passed as either a query
parameter or a header, based on the chosen authentication type. If the download
fails with certain 403 errors, a fallback request with alternative headers is attempted.
It writes the dataset contents to a temporary file and returns the file path.

:param url: The dataset's source URL.
:type url: str
:param authentication_type: The type of authentication mechanism to use (e.g.,
1 for parameter-based, 2 for header-based).
:type authentication_type: int
:param api_key_parameter_name: The name of the API key parameter/header. It is
optional if the dataset is publicly accessible or no authentication is
required.
:type api_key_parameter_name: str, optional
:param api_key_parameter_value: The value of the API key to authenticate the
request. It is optional if no authentication is required.
:type api_key_parameter_value: str, optional
:return: The file path where the downloaded dataset is temporarily stored.
:type return: str
:raises RequestException: If all attempts to download the dataset fail.
parameter or a header, based on the chosen authentication type. It implements
adaptive fallback strategies for HTTP 403 errors and SSL certificate errors.
"""
file_path = os.path.join(os.getcwd(), str(uuid.uuid4()))

def make_request(url, params=None, headers=None):
params = {api_key_parameter_name: api_key_parameter_value} if authentication_type == 1 else None
headers = {api_key_parameter_name: api_key_parameter_value} if authentication_type == 2 else None

tried_options = set()
current_headers = headers
verify_ssl = True

for attempt in range(3):
try:
response = requests.get(url, params=params, headers=headers, allow_redirects=True)
response = requests.get(
url,
params=params,
headers=current_headers,
allow_redirects=True,
verify=verify_ssl
)
response.raise_for_status()
return response.content
except HTTPError as e:
return None if e.response.status_code == 403 else RequestException(
f"HTTP error {e} when accessing {url}. A fallback attempt with alternative headers will be made.")
except RequestException as e:
raise RequestException(f"Request failed: {e}")

file_path = os.path.join(os.getcwd(), str(uuid.uuid4()))
if not verify_ssl:
import warnings
warnings.warn(
f"SSL verification was disabled when downloading {url}."
)

params = {api_key_parameter_name: api_key_parameter_value} if authentication_type == 1 else None
headers = {api_key_parameter_name: api_key_parameter_value} if authentication_type == 2 else None
with open(file_path, "wb") as f:
f.write(response.content)
return file_path

except requests.exceptions.HTTPError as e:
if e.response.status_code == 403 and "fallback_headers" not in tried_options:
current_headers = get_fallback_headers(url, headers)
tried_options.add("fallback_headers")
continue

zip_file = make_request(url, params, headers) or (
make_request(url, params, {**FALLBACK_HEADERS, **(headers or {}),
"Referer": f"{urlparse(url).scheme}://{urlparse(url).netloc}/",
"Host": urlparse(url).netloc})
)
except requests.exceptions.SSLError:
if "disable_ssl" not in tried_options:
verify_ssl = False
tried_options.add("disable_ssl")
continue

if zip_file is None:
raise RequestException(f"FAILURE! Retry attempts failed for {url}.")
except requests.exceptions.RequestException:
pass

with open(file_path, "wb") as f:
f.write(zip_file)
if "fallback_headers" not in tried_options:
current_headers = get_fallback_headers(url, headers)
tried_options.add("fallback_headers")
elif "disable_ssl" not in tried_options:
verify_ssl = False
tried_options.add("disable_ssl")
else:
break

return file_path
raise requests.exceptions.RequestException(f"FAILURE! All download attempts failed for {url}.")


#########################
Expand All @@ -137,14 +157,14 @@ def make_request(url, params=None, headers=None):


def are_overlapping_boxes(
source_minimum_latitude,
source_maximum_latitude,
source_minimum_longitude,
source_maximum_longitude,
filter_minimum_latitude,
filter_maximum_latitude,
filter_minimum_longitude,
filter_maximum_longitude,
source_minimum_latitude,
source_maximum_latitude,
source_minimum_longitude,
source_maximum_longitude,
filter_minimum_latitude,
filter_maximum_latitude,
filter_minimum_longitude,
filter_maximum_longitude,
):
"""
Verifies if two boxes are overlapping in two dimensions.
Expand Down Expand Up @@ -178,7 +198,7 @@ def are_overlapping_boxes(


def are_overlapping_edges(
source_minimum, source_maximum, filter_minimum, filter_maximum
source_minimum, source_maximum, filter_minimum, filter_maximum
):
"""
Verifies if two edges are overlapping in one dimension.
Expand Down Expand Up @@ -238,7 +258,7 @@ def is_readable(file_path, load_func):


def create_latest_url(
country_code, subdivision_name, provider, data_type, mdb_source_id
country_code, subdivision_name, provider, data_type, mdb_source_id
):
"""
Creates the latest URL for an MDB Source.
Expand Down Expand Up @@ -270,7 +290,7 @@ def create_latest_url(


def create_filename(
country_code, subdivision_name, provider, data_type, mdb_source_id, extension
country_code, subdivision_name, provider, data_type, mdb_source_id, extension
):
"""
Creates the filename for an MDB Source.
Expand Down Expand Up @@ -401,9 +421,9 @@ def extract_gtfs_bounding_box(file_path):

stops_required_columns = {STOP_LAT, STOP_LON}
stops_are_present = (
stops is not None
and stops_required_columns.issubset(stops.columns)
and not (stops[STOP_LAT].dropna().empty or stops[STOP_LON].dropna().empty)
stops is not None
and stops_required_columns.issubset(stops.columns)
and not (stops[STOP_LAT].dropna().empty or stops[STOP_LON].dropna().empty)
)

minimum_latitude = stops[STOP_LAT].dropna().min() if stops_are_present else None
Expand Down
67 changes: 53 additions & 14 deletions tools/tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from unittest import TestCase, skip
from unittest.mock import patch, Mock

import pandas as pd
import requests
from freezegun import freeze_time
from requests.exceptions import HTTPError

from tools.helpers import (
are_overlapping_edges,
are_overlapping_boxes,
Expand All @@ -18,9 +24,6 @@
normalize,
download_dataset,
)
import pandas as pd
from freezegun import freeze_time
from requests.exceptions import HTTPError


class TestVerificationFunctions(TestCase):
Expand Down Expand Up @@ -396,7 +399,7 @@ def test_to_csv(self):
@patch("tools.helpers.os")
@patch("tools.helpers.requests.get")
def test_download_dataset_auth_type_empty(
self, mock_requests, mock_os, mock_uuid4, mock_open
self, mock_requests, mock_os, mock_uuid4, mock_open
):
test_authentication_type = None
test_api_key_parameter_name = None
Expand All @@ -422,7 +425,7 @@ def test_download_dataset_auth_type_empty(
@patch("tools.helpers.os")
@patch("tools.helpers.requests.get")
def test_download_dataset_auth_type_0(
self, mock_requests, mock_os, mock_uuid4, mock_open
self, mock_requests, mock_os, mock_uuid4, mock_open
):
test_authentication_type = 0
test_api_key_parameter_name = None
Expand All @@ -448,7 +451,7 @@ def test_download_dataset_auth_type_0(
@patch("tools.helpers.os")
@patch("tools.helpers.requests.get")
def test_download_dataset_auth_type_1(
self, mock_requests, mock_os, mock_uuid4, mock_open
self, mock_requests, mock_os, mock_uuid4, mock_open
):
test_authentication_type = 1
test_api_key_parameter_name = "some_name"
Expand Down Expand Up @@ -477,7 +480,7 @@ def test_download_dataset_auth_type_1(
@patch("tools.helpers.os")
@patch("tools.helpers.requests.get")
def test_download_dataset_auth_type_2(
self, mock_requests, mock_os, mock_uuid4, mock_open
self, mock_requests, mock_os, mock_uuid4, mock_open
):
test_authentication_type = 2
test_api_key_parameter_name = "some_name"
Expand Down Expand Up @@ -506,7 +509,7 @@ def test_download_dataset_auth_type_2(
@patch("tools.helpers.os")
@patch("tools.helpers.requests.get")
def test_download_dataset_exception(
self, mock_requests, mock_os, mock_uuid4, mock_open
self, mock_requests, mock_os, mock_uuid4, mock_open
):
test_authentication_type = None
test_api_key_parameter_name = None
Expand All @@ -527,7 +530,6 @@ def test_download_dataset_exception(
@patch("tools.helpers.os")
@patch("tools.helpers.requests.get")
def test_download_dataset_403_fallback_success(self, mock_requests, mock_os, mock_uuid4, mock_open):

response_403 = Mock(status_code=403)
response_403.raise_for_status.side_effect = HTTPError(response=response_403)

Expand All @@ -537,7 +539,7 @@ def test_download_dataset_403_fallback_success(self, mock_requests, mock_os, moc
mock_os.path.join.return_value = self.test_path

under_test = download_dataset(url=self.test_url, authentication_type=0, api_key_parameter_name=None,
api_key_parameter_value=None, )
api_key_parameter_value=None, )

self.assertEqual(under_test, self.test_path)
self.assertEqual(mock_requests.call_count, 2)
Expand All @@ -555,16 +557,53 @@ def test_download_dataset_403_fallback_failure(self, mock_requests, mock_os, moc
response_403_1.raise_for_status.side_effect = HTTPError(response=response_403_1)
response_403_2 = Mock(status_code=403)
response_403_2.raise_for_status.side_effect = HTTPError(response=response_403_2)
response_403_3 = Mock(status_code=403)
response_403_3.raise_for_status.side_effect = HTTPError(response=response_403_3)

mock_requests.side_effect = [response_403_1, response_403_2]
mock_requests.side_effect = [response_403_1, response_403_2, response_403_3]

mock_os.path.join.return_value = self.test_path
self.assertRaises(RequestException, download_dataset, url=self.test_url,
authentication_type=test_authentication_type, api_key_parameter_name=test_api_key_parameter_name,
api_key_parameter_value=test_api_key_parameter_value, )
authentication_type=test_authentication_type,
api_key_parameter_name=test_api_key_parameter_name,
api_key_parameter_value=test_api_key_parameter_value)

self.assertEqual(mock_requests.call_count, 2)
self.assertEqual(mock_requests.call_count, 3)
mock_os.path.join.assert_called_once()
mock_os.getcwd.assert_called_once()
mock_uuid4.assert_called_once()
mock_open.assert_not_called()

@patch("tools.helpers.open")
@patch("tools.helpers.uuid.uuid4")
@patch("tools.helpers.os")
@patch("tools.helpers.requests.get")
def test_download_dataset_ssl_error_fallback(self, mock_requests, mock_os, mock_uuid4, mock_open):
test_authentication_type = 0
test_api_key_parameter_name = None
test_api_key_parameter_value = None

ssl_error = requests.exceptions.SSLError("SSL Certificate Verification Failed")

response_200 = Mock(status_code=200, content=b"file_content")

mock_requests.side_effect = [ssl_error, response_200]
mock_os.path.join.return_value = self.test_path

under_test = download_dataset(
url=self.test_url,
authentication_type=test_authentication_type,
api_key_parameter_name=test_api_key_parameter_name,
api_key_parameter_value=test_api_key_parameter_value,
)

self.assertEqual(under_test, self.test_path)
self.assertEqual(mock_requests.call_count, 2)

self.assertTrue(mock_requests.call_args_list[0].kwargs["verify"])
self.assertFalse(mock_requests.call_args_list[1].kwargs["verify"])

mock_os.path.join.assert_called_once()
mock_os.getcwd.assert_called_once()
mock_uuid4.assert_called_once()
mock_open.assert_called_once()