Skip to content

Commit 3957fe7

Browse files
committed
Fix OAuth request handling
1 parent d88b9f5 commit 3957fe7

File tree

1 file changed

+63
-43
lines changed

1 file changed

+63
-43
lines changed

microSALT/utils/pubmlst/client.py

Lines changed: 63 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,81 @@
1-
import requests
21
from urllib.parse import urlencode
2+
3+
import requests
4+
from rauth import OAuth1Session
5+
6+
from microSALT import logger
7+
from microSALT.utils.pubmlst.authentication import load_session_credentials
8+
from microSALT.utils.pubmlst.constants import HTTPMethod, RequestType, ResponseHandler
9+
from microSALT.utils.pubmlst.exceptions import PUBMLSTError, SessionTokenRequestError
310
from microSALT.utils.pubmlst.helpers import (
411
BASE_API,
512
generate_oauth_header,
613
load_auth_credentials,
7-
parse_pubmlst_url
14+
parse_pubmlst_url,
815
)
9-
from microSALT.utils.pubmlst.constants import RequestType, HTTPMethod, ResponseHandler
10-
from microSALT.utils.pubmlst.exceptions import PUBMLSTError, SessionTokenRequestError
11-
from microSALT.utils.pubmlst.authentication import load_session_credentials
12-
from microSALT import logger
16+
1317

1418
class PubMLSTClient:
1519
"""Client for interacting with the PubMLST authenticated API."""
1620

1721
def __init__(self):
1822
"""Initialize the PubMLST client."""
1923
try:
20-
self.consumer_key, self.consumer_secret, self.access_token, self.access_secret = load_auth_credentials()
24+
self.consumer_key, self.consumer_secret, self.access_token, self.access_secret = (
25+
load_auth_credentials()
26+
)
2127
self.database = "pubmlst_test_seqdef"
2228
self.session_token, self.session_secret = load_session_credentials(self.database)
2329
except PUBMLSTError as e:
2430
logger.error(f"Failed to initialize PubMLST client: {e}")
2531
raise
2632

27-
2833
@staticmethod
2934
def parse_pubmlst_url(url: str):
3035
"""
3136
Wrapper for the parse_pubmlst_url function.
3237
"""
3338
return parse_pubmlst_url(url)
3439

35-
36-
def _make_request(self, request_type: RequestType, method: HTTPMethod, url: str, db: str = None, response_handler: ResponseHandler = ResponseHandler.JSON):
37-
""" Handle API requests."""
40+
def _make_request(
41+
self,
42+
request_type: RequestType,
43+
method: HTTPMethod,
44+
url: str,
45+
db: str = None,
46+
response_handler: ResponseHandler = ResponseHandler.JSON,
47+
):
48+
"""Handle API requests."""
3849
try:
39-
if db:
40-
session_token, session_secret = load_session_credentials(db)
41-
else:
42-
session_token, session_secret = self.session_token, self.session_secret
43-
4450
if request_type == RequestType.AUTH:
45-
headers = {
46-
"Authorization": generate_oauth_header(url, self.consumer_key, self.consumer_secret, self.access_token, self.access_secret)
47-
}
51+
access_token = self.access_token
52+
access_secret = self.access_secret
53+
log_database = "authentication"
4854
elif request_type == RequestType.DB:
49-
headers = {
50-
"Authorization": generate_oauth_header(url, self.consumer_key, self.consumer_secret, session_token, session_secret)
51-
}
55+
access_token, access_secret = load_session_credentials(db or self.database)
56+
log_database = db or self.database
5257
else:
5358
raise ValueError(f"Unsupported request type: {request_type}")
5459

55-
if method == HTTPMethod.GET:
56-
response = requests.get(url, headers=headers)
57-
elif method == HTTPMethod.POST:
58-
response = requests.post(url, headers=headers)
59-
elif method == HTTPMethod.PUT:
60-
response = requests.put(url, headers=headers)
61-
else:
62-
raise ValueError(f"Unsupported HTTP method: {method}")
63-
60+
logger.info(f"Making request to {url} for database {log_database}")
61+
logger.info(f"Using session token: {access_token[:6]}****")
62+
63+
session = OAuth1Session(
64+
consumer_key=self.consumer_key,
65+
consumer_secret=self.consumer_secret,
66+
access_token=access_token,
67+
access_token_secret=access_secret,
68+
)
69+
70+
headers = {"User-Agent": "BIGSdb API Client"}
71+
response = session.request(method.value, url, headers=headers)
72+
73+
if response.status_code == 401:
74+
logger.error(f"401 Unauthorized: {response.text}")
75+
6476
response.raise_for_status()
65-
77+
78+
# Process the response based on the requested handler type
6679
if response_handler == ResponseHandler.CONTENT:
6780
return response.content
6881
elif response_handler == ResponseHandler.TEXT:
@@ -73,44 +86,51 @@ def _make_request(self, request_type: RequestType, method: HTTPMethod, url: str,
7386
raise ValueError(f"Unsupported response handler: {response_handler}")
7487

7588
except requests.exceptions.HTTPError as e:
76-
raise SessionTokenRequestError(db or self.database, e.response.status_code, e.response.text) from e
89+
raise SessionTokenRequestError(
90+
db or self.database, e.response.status_code, e.response.text
91+
) from e
7792
except requests.exceptions.RequestException as e:
7893
logger.error(f"Request failed: {e}")
7994
raise PUBMLSTError(f"Request failed: {e}") from e
8095
except Exception as e:
8196
logger.error(f"Unexpected error during request: {e}")
8297
raise PUBMLSTError(f"An unexpected error occurred: {e}") from e
8398

84-
8599
def query_databases(self):
86100
"""Query available PubMLST databases."""
87101
url = f"{BASE_API}/db"
88-
return self._make_request(RequestType.DB, HTTPMethod.GET, url, response_handler=ResponseHandler.JSON)
89-
102+
return self._make_request(
103+
RequestType.DB, HTTPMethod.GET, url, response_handler=ResponseHandler.JSON
104+
)
90105

91106
def download_locus(self, db: str, locus: str, **kwargs):
92107
"""Download locus sequence files."""
93108
base_url = f"{BASE_API}/db/{db}/loci/{locus}/alleles_fasta"
94109
query_string = urlencode(kwargs)
95110
url = f"{base_url}?{query_string}" if query_string else base_url
96-
return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.TEXT)
97-
111+
return self._make_request(
112+
RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.TEXT
113+
)
98114

99115
def download_profiles_csv(self, db: str, scheme_id: int):
100116
"""Download MLST profiles in CSV format."""
101117
if not scheme_id:
102118
raise ValueError("Scheme ID is required to download profiles CSV.")
103119
url = f"{BASE_API}/db/{db}/schemes/{scheme_id}/profiles_csv"
104-
return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.TEXT)
105-
120+
return self._make_request(
121+
RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.TEXT
122+
)
106123

107124
def retrieve_scheme_info(self, db: str, scheme_id: int):
108125
"""Retrieve information about a specific MLST scheme."""
109126
url = f"{BASE_API}/db/{db}/schemes/{scheme_id}"
110-
return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.JSON)
111-
127+
return self._make_request(
128+
RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.JSON
129+
)
112130

113131
def list_schemes(self, db: str):
114132
"""List available MLST schemes for a specific database."""
115133
url = f"{BASE_API}/db/{db}/schemes"
116-
return self._make_request(RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.JSON)
134+
return self._make_request(
135+
RequestType.DB, HTTPMethod.GET, url, db=db, response_handler=ResponseHandler.JSON
136+
)

0 commit comments

Comments
 (0)