Skip to content

Commit 1ecf105

Browse files
authored
Merge pull request #198 from Clinical-Genomics/replace-custom-oauth-header (patch)
### Changed - Replace custom oauth header
2 parents ff3549e + f1c8750 commit 1ecf105

File tree

3 files changed

+108
-98
lines changed

3 files changed

+108
-98
lines changed

microSALT/utils/pubmlst/authentication.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,33 @@
11
import json
22
import os
33
from datetime import datetime, timedelta
4+
45
from dateutil import parser
56
from rauth import OAuth1Session
7+
68
from microSALT import logger
7-
from microSALT.utils.pubmlst.helpers import BASE_API, save_session_token, load_auth_credentials, get_path, folders_config, credentials_path_key, pubmlst_session_credentials_file_name
89
from microSALT.utils.pubmlst.exceptions import (
910
PUBMLSTError,
1011
SessionTokenRequestError,
1112
SessionTokenResponseError,
1213
)
14+
from microSALT.utils.pubmlst.helpers import (
15+
BASE_API,
16+
credentials_path_key,
17+
folders_config,
18+
get_path,
19+
load_auth_credentials,
20+
pubmlst_session_credentials_file_name,
21+
save_session_token,
22+
)
1323

1424
session_token_validity = 12 # 12-hour validity
1525
session_expiration_buffer = 60 # 60-second buffer
1626

27+
1728
def get_new_session_token(db: str):
1829
"""Request a new session token using all credentials for a specific database."""
19-
logger.debug("Fetching a new session token for database '{db}'...")
30+
logger.debug(f"Fetching a new session token for database '{db}'...")
2031

2132
try:
2233
consumer_key, consumer_secret, access_token, access_secret = load_auth_credentials()
@@ -30,8 +41,8 @@ def get_new_session_token(db: str):
3041
access_token_secret=access_secret,
3142
)
3243

33-
response = session.get(url, headers={"User-Agent": "BIGSdb downloader"})
34-
logger.debug("Response Status Code: {status_code}")
44+
response = session.get(url, headers={"User-Agent": "BIGSdb API downloader"})
45+
logger.debug(f"Response Status Code: {response.status_code}")
3546

3647
if response.ok:
3748
try:
@@ -52,23 +63,23 @@ def get_new_session_token(db: str):
5263
except (ValueError, KeyError) as e:
5364
raise SessionTokenResponseError(db, f"Invalid response format: {str(e)}")
5465
else:
55-
raise SessionTokenRequestError(
56-
db, response.status_code, response.text
57-
)
66+
raise SessionTokenRequestError(db, response.status_code, response.text)
5867

5968
except PUBMLSTError as e:
6069
logger.error(f"Error during token fetching: {e}")
6170
raise
6271
except Exception as e:
6372
logger.error(f"Unexpected error: {e}")
64-
raise PUBMLSTError(f"Unexpected error while fetching session token for database '{db}': {e}")
73+
raise PUBMLSTError(
74+
f"Unexpected error while fetching session token for database '{db}': {e}"
75+
)
76+
6577

6678
def load_session_credentials(db: str):
6779
"""Load session token from file for a specific database."""
6880
try:
6981
credentials_file = os.path.join(
70-
get_path(folders_config, credentials_path_key),
71-
pubmlst_session_credentials_file_name
82+
get_path(folders_config, credentials_path_key), pubmlst_session_credentials_file_name
7283
)
7384

7485
if not os.path.exists(credentials_file):
@@ -83,7 +94,9 @@ def load_session_credentials(db: str):
8394

8495
db_session_data = all_sessions.get("databases", {}).get(db)
8596
if not db_session_data:
86-
logger.debug(f"No session token found for database '{db}'. Fetching a new session token.")
97+
logger.debug(
98+
f"No session token found for database '{db}'. Fetching a new session token."
99+
)
87100
return get_new_session_token(db)
88101

89102
expiration = parser.parse(db_session_data.get("expiration", ""))
@@ -94,7 +107,9 @@ def load_session_credentials(db: str):
94107

95108
return session_token, session_secret
96109

97-
logger.debug(f"Session token for database '{db}' has expired. Fetching a new session token.")
110+
logger.debug(
111+
f"Session token for database '{db}' has expired. Fetching a new session token."
112+
)
98113
return get_new_session_token(db)
99114

100115
except PUBMLSTError as e:
@@ -103,4 +118,3 @@ def load_session_credentials(db: str):
103118
except Exception as e:
104119
logger.error(f"Unexpected error: {e}")
105120
raise PUBMLSTError(f"Unexpected error while loading session token for database '{db}': {e}")
106-

microSALT/utils/pubmlst/client.py

Lines changed: 63 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,80 @@
1-
import requests
21
from urllib.parse import urlencode
2+
3+
import requests
4+
from rauth import OAuth1Session
5+
6+
from microSALT import logger
7+
from microSALT.utils.pubmlst.authentication import load_session_credentials
8+
from microSALT.utils.pubmlst.constants import HTTPMethod, RequestType, ResponseHandler
9+
from microSALT.utils.pubmlst.exceptions import PUBMLSTError, SessionTokenRequestError
310
from microSALT.utils.pubmlst.helpers import (
411
BASE_API,
5-
generate_oauth_header,
612
load_auth_credentials,
7-
parse_pubmlst_url
13+
parse_pubmlst_url,
814
)
9-
from microSALT.utils.pubmlst.constants import RequestType, HTTPMethod, ResponseHandler
10-
from microSALT.utils.pubmlst.exceptions import PUBMLSTError, SessionTokenRequestError
11-
from microSALT.utils.pubmlst.authentication import load_session_credentials
12-
from microSALT import logger
15+
1316

1417
class PubMLSTClient:
1518
"""Client for interacting with the PubMLST authenticated API."""
1619

1720
def __init__(self):
1821
"""Initialize the PubMLST client."""
1922
try:
20-
self.consumer_key, self.consumer_secret, self.access_token, self.access_secret = load_auth_credentials()
23+
self.consumer_key, self.consumer_secret, self.access_token, self.access_secret = (
24+
load_auth_credentials()
25+
)
2126
self.database = "pubmlst_test_seqdef"
2227
self.session_token, self.session_secret = load_session_credentials(self.database)
2328
except PUBMLSTError as e:
2429
logger.error(f"Failed to initialize PubMLST client: {e}")
2530
raise
2631

27-
2832
@staticmethod
2933
def parse_pubmlst_url(url: str):
3034
"""
3135
Wrapper for the parse_pubmlst_url function.
3236
"""
3337
return parse_pubmlst_url(url)
3438

35-
36-
def _make_request(self, request_type: RequestType, method: HTTPMethod, url: str, db: str = None, response_handler: ResponseHandler = ResponseHandler.JSON):
37-
""" Handle API requests."""
39+
def _make_request(
40+
self,
41+
request_type: RequestType,
42+
method: HTTPMethod,
43+
url: str,
44+
db: str = None,
45+
response_handler: ResponseHandler = ResponseHandler.JSON,
46+
):
47+
"""Handle API requests."""
3848
try:
39-
if db:
40-
session_token, session_secret = load_session_credentials(db)
41-
else:
42-
session_token, session_secret = self.session_token, self.session_secret
43-
4449
if request_type == RequestType.AUTH:
45-
headers = {
46-
"Authorization": generate_oauth_header(url, self.consumer_key, self.consumer_secret, self.access_token, self.access_secret)
47-
}
50+
access_token = self.access_token
51+
access_secret = self.access_secret
52+
log_database = "authentication"
4853
elif request_type == RequestType.DB:
49-
headers = {
50-
"Authorization": generate_oauth_header(url, self.consumer_key, self.consumer_secret, session_token, session_secret)
51-
}
54+
access_token, access_secret = load_session_credentials(db or self.database)
55+
log_database = db or self.database
5256
else:
5357
raise ValueError(f"Unsupported request type: {request_type}")
5458

55-
if method == HTTPMethod.GET:
56-
response = requests.get(url, headers=headers)
57-
elif method == HTTPMethod.POST:
58-
response = requests.post(url, headers=headers)
59-
elif method == HTTPMethod.PUT:
60-
response = requests.put(url, headers=headers)
61-
else:
62-
raise ValueError(f"Unsupported HTTP method: {method}")
63-
59+
logger.info(f"Making request to {url} for database {log_database}")
60+
logger.info(f"Using session token: {access_token[:6]}****")
61+
62+
session = OAuth1Session(
63+
consumer_key=self.consumer_key,
64+
consumer_secret=self.consumer_secret,
65+
access_token=access_token,
66+
access_token_secret=access_secret,
67+
)
68+
69+
headers = {"User-Agent": "BIGSdb API Client"}
70+
response = session.request(method.value, url, headers=headers)
71+
72+
if response.status_code == 401:
73+
logger.error(f"401 Unauthorized: {response.text}")
74+
6475
response.raise_for_status()
65-
76+
77+
# Process the response based on the requested handler type
6678
if response_handler == ResponseHandler.CONTENT:
6779
return response.content
6880
elif response_handler == ResponseHandler.TEXT:
@@ -73,44 +85,51 @@ def _make_request(self, request_type: RequestType, method: HTTPMethod, url: str,
7385
raise ValueError(f"Unsupported response handler: {response_handler}")
7486

7587
except requests.exceptions.HTTPError as e:
76-
raise SessionTokenRequestError(db or self.database, e.response.status_code, e.response.text) from e
88+
raise SessionTokenRequestError(
89+
db or self.database, e.response.status_code, e.response.text
90+
) from e
7791
except requests.exceptions.RequestException as e:
7892
logger.error(f"Request failed: {e}")
7993
raise PUBMLSTError(f"Request failed: {e}") from e
8094
except Exception as e:
8195
logger.error(f"Unexpected error during request: {e}")
8296
raise PUBMLSTError(f"An unexpected error occurred: {e}") from e
8397

84-
8598
def query_databases(self):
8699
"""Query available PubMLST databases."""
87100
url = f"{BASE_API}/db"
88-
return self._make_request(RequestType.DB, HTTPMethod.GET, url, response_handler=ResponseHandler.JSON)
89-
101+
return self._make_request(
102+
RequestType.DB, HTTPMethod.GET, url, response_handler=ResponseHandler.JSON
103+
)
90104

91105
def download_locus(self, db: str, locus: str, **kwargs):
92106
"""Download locus sequence files."""
93107
base_url = f"{BASE_API}/db/{db}/loci/{locus}/alleles_fasta"
94108
query_string = urlencode(kwargs)
95109
url = f"{base_url}?{query_string}" if query_string else base_url
96-
return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.TEXT)
97-
110+
return self._make_request(
111+
RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.TEXT
112+
)
98113

99114
def download_profiles_csv(self, db: str, scheme_id: int):
100115
"""Download MLST profiles in CSV format."""
101116
if not scheme_id:
102117
raise ValueError("Scheme ID is required to download profiles CSV.")
103118
url = f"{BASE_API}/db/{db}/schemes/{scheme_id}/profiles_csv"
104-
return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.TEXT)
105-
119+
return self._make_request(
120+
RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.TEXT
121+
)
106122

107123
def retrieve_scheme_info(self, db: str, scheme_id: int):
108124
"""Retrieve information about a specific MLST scheme."""
109125
url = f"{BASE_API}/db/{db}/schemes/{scheme_id}"
110-
return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.JSON)
111-
126+
return self._make_request(
127+
RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.JSON
128+
)
112129

113130
def list_schemes(self, db: str):
114131
"""List available MLST schemes for a specific database."""
115132
url = f"{BASE_API}/db/{db}/schemes"
116-
return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.JSON)
133+
return self._make_request(
134+
RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.JSON
135+
)

microSALT/utils/pubmlst/helpers.py

Lines changed: 18 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
1-
import os
21
import base64
32
import hashlib
4-
import json
53
import hmac
4+
import json
5+
import os
66
import time
77
from pathlib import Path
88
from urllib.parse import quote_plus, urlencode
9+
910
from werkzeug.exceptions import NotFound
11+
1012
from microSALT import app, logger
11-
from microSALT.utils.pubmlst.exceptions import PUBMLSTError, PathResolutionError, CredentialsFileNotFound, InvalidCredentials, SaveSessionError, InvalidURLError
1213
from microSALT.utils.pubmlst.constants import Encoding, url_map
14+
from microSALT.utils.pubmlst.exceptions import (
15+
CredentialsFileNotFound,
16+
InvalidCredentials,
17+
InvalidURLError,
18+
PathResolutionError,
19+
PUBMLSTError,
20+
SaveSessionError,
21+
)
1322

1423
BASE_WEB = "https://pubmlst.org/bigsdb"
1524
BASE_API = "https://rest.pubmlst.org"
@@ -21,6 +30,7 @@
2130
pubmlst_config = app.config["pubmlst"]
2231
folders_config = app.config["folders"]
2332

33+
2434
def get_path(config, config_key: str):
2535
"""Get and expand the file path from the configuration."""
2636
try:
@@ -41,8 +51,7 @@ def load_auth_credentials():
4151
"""Load client ID, client secret, access token, and access secret from credentials file."""
4252
try:
4353
credentials_file = os.path.join(
44-
get_path(folders_config, credentials_path_key),
45-
pubmlst_auth_credentials_file_name
54+
get_path(folders_config, credentials_path_key), pubmlst_auth_credentials_file_name
4655
)
4756

4857
if not os.path.exists(credentials_file):
@@ -81,37 +90,7 @@ def load_auth_credentials():
8190
raise
8291
except Exception as e:
8392
raise PUBMLSTError("An unexpected error occurred while loading credentials: {e}")
84-
85-
86-
def generate_oauth_header(url: str, oauth_consumer_key: str, oauth_consumer_secret: str, oauth_token: str, oauth_token_secret: str):
87-
"""Generate the OAuth1 Authorization header."""
88-
oauth_timestamp = str(int(time.time()))
89-
oauth_nonce = base64.urlsafe_b64encode(os.urandom(32)).decode(Encoding.UTF8.value).strip("=")
90-
oauth_signature_method = "HMAC-SHA1"
91-
oauth_version = "1.0"
92-
93-
oauth_params = {
94-
"oauth_consumer_key": oauth_consumer_key,
95-
"oauth_token": oauth_token,
96-
"oauth_signature_method": oauth_signature_method,
97-
"oauth_timestamp": oauth_timestamp,
98-
"oauth_nonce": oauth_nonce,
99-
"oauth_version": oauth_version,
100-
}
101-
102-
params_encoded = urlencode(sorted(oauth_params.items()))
103-
base_string = f"GET&{quote_plus(url)}&{quote_plus(params_encoded)}"
104-
signing_key = f"{oauth_consumer_secret}&{oauth_token_secret}"
105-
106-
hashed = hmac.new(signing_key.encode(Encoding.UTF8.value), base_string.encode(Encoding.UTF8.value), hashlib.sha1)
107-
oauth_signature = base64.b64encode(hashed.digest()).decode(Encoding.UTF8.value)
108-
109-
oauth_params["oauth_signature"] = oauth_signature
110-
111-
auth_header = "OAuth " + ", ".join(
112-
[f'{quote_plus(k)}="{quote_plus(v)}"' for k, v in oauth_params.items()]
113-
)
114-
return auth_header
93+
11594

11695
def save_session_token(db: str, token: str, secret: str, expiration_date: str):
11796
"""Save session token, secret, and expiration to a JSON file for the specified database."""
@@ -123,8 +102,7 @@ def save_session_token(db: str, token: str, secret: str, expiration_date: str):
123102
}
124103

125104
credentials_file = os.path.join(
126-
get_path(folders_config, credentials_path_key),
127-
pubmlst_session_credentials_file_name
105+
get_path(folders_config, credentials_path_key), pubmlst_session_credentials_file_name
128106
)
129107

130108
if os.path.exists(credentials_file):
@@ -141,16 +119,15 @@ def save_session_token(db: str, token: str, secret: str, expiration_date: str):
141119
with open(credentials_file, "w") as f:
142120
json.dump(all_sessions, f, indent=4)
143121

144-
logger.debug(
145-
f"Session token for database '{db}' saved to '{credentials_file}'."
146-
)
122+
logger.debug(f"Session token for database '{db}' saved to '{credentials_file}'.")
147123
except (IOError, OSError) as e:
148124
raise SaveSessionError(db, f"I/O error: {e}")
149125
except ValueError as e:
150126
raise SaveSessionError(db, f"Invalid data format: {e}")
151127
except Exception as e:
152128
raise SaveSessionError(db, f"Unexpected error: {e}")
153129

130+
154131
def parse_pubmlst_url(url: str):
155132
"""
156133
Match a URL against the URL map and return extracted parameters.

0 commit comments

Comments
 (0)