1- import requests
21from urllib .parse import urlencode
2+
3+ import requests
4+ from rauth import OAuth1Session
5+
6+ from microSALT import logger
7+ from microSALT .utils .pubmlst .authentication import load_session_credentials
8+ from microSALT .utils .pubmlst .constants import HTTPMethod , RequestType , ResponseHandler
9+ from microSALT .utils .pubmlst .exceptions import PUBMLSTError , SessionTokenRequestError
310from microSALT .utils .pubmlst .helpers import (
411 BASE_API ,
5- generate_oauth_header ,
612 load_auth_credentials ,
7- parse_pubmlst_url
13+ parse_pubmlst_url ,
814)
9- from microSALT .utils .pubmlst .constants import RequestType , HTTPMethod , ResponseHandler
10- from microSALT .utils .pubmlst .exceptions import PUBMLSTError , SessionTokenRequestError
11- from microSALT .utils .pubmlst .authentication import load_session_credentials
12- from microSALT import logger
15+
1316
1417class PubMLSTClient :
1518 """Client for interacting with the PubMLST authenticated API."""
1619
1720 def __init__ (self ):
1821 """Initialize the PubMLST client."""
1922 try :
20- self .consumer_key , self .consumer_secret , self .access_token , self .access_secret = load_auth_credentials ()
23+ self .consumer_key , self .consumer_secret , self .access_token , self .access_secret = (
24+ load_auth_credentials ()
25+ )
2126 self .database = "pubmlst_test_seqdef"
2227 self .session_token , self .session_secret = load_session_credentials (self .database )
2328 except PUBMLSTError as e :
2429 logger .error (f"Failed to initialize PubMLST client: { e } " )
2530 raise
2631
27-
2832 @staticmethod
2933 def parse_pubmlst_url (url : str ):
3034 """
3135 Wrapper for the parse_pubmlst_url function.
3236 """
3337 return parse_pubmlst_url (url )
3438
35-
36- def _make_request (self , request_type : RequestType , method : HTTPMethod , url : str , db : str = None , response_handler : ResponseHandler = ResponseHandler .JSON ):
37- """ Handle API requests."""
39+ def _make_request (
40+ self ,
41+ request_type : RequestType ,
42+ method : HTTPMethod ,
43+ url : str ,
44+ db : str = None ,
45+ response_handler : ResponseHandler = ResponseHandler .JSON ,
46+ ):
47+ """Handle API requests."""
3848 try :
39- if db :
40- session_token , session_secret = load_session_credentials (db )
41- else :
42- session_token , session_secret = self .session_token , self .session_secret
43-
4449 if request_type == RequestType .AUTH :
45- headers = {
46- "Authorization" : generate_oauth_header ( url , self .consumer_key , self . consumer_secret , self . access_token , self . access_secret )
47- }
50+ access_token = self . access_token
51+ access_secret = self .access_secret
52+ log_database = "authentication"
4853 elif request_type == RequestType .DB :
49- headers = {
50- "Authorization" : generate_oauth_header (url , self .consumer_key , self .consumer_secret , session_token , session_secret )
51- }
54+ access_token , access_secret = load_session_credentials (db or self .database )
55+ log_database = db or self .database
5256 else :
5357 raise ValueError (f"Unsupported request type: { request_type } " )
5458
55- if method == HTTPMethod .GET :
56- response = requests .get (url , headers = headers )
57- elif method == HTTPMethod .POST :
58- response = requests .post (url , headers = headers )
59- elif method == HTTPMethod .PUT :
60- response = requests .put (url , headers = headers )
61- else :
62- raise ValueError (f"Unsupported HTTP method: { method } " )
63-
59+ logger .info (f"Making request to { url } for database { log_database } " )
60+ logger .info (f"Using session token: { access_token [:6 ]} ****" )
61+
62+ session = OAuth1Session (
63+ consumer_key = self .consumer_key ,
64+ consumer_secret = self .consumer_secret ,
65+ access_token = access_token ,
66+ access_token_secret = access_secret ,
67+ )
68+
69+ headers = {"User-Agent" : "BIGSdb API Client" }
70+ response = session .request (method .value , url , headers = headers )
71+
72+ if response .status_code == 401 :
73+ logger .error (f"401 Unauthorized: { response .text } " )
74+
6475 response .raise_for_status ()
65-
76+
77+ # Process the response based on the requested handler type
6678 if response_handler == ResponseHandler .CONTENT :
6779 return response .content
6880 elif response_handler == ResponseHandler .TEXT :
@@ -73,44 +85,51 @@ def _make_request(self, request_type: RequestType, method: HTTPMethod, url: str,
7385 raise ValueError (f"Unsupported response handler: { response_handler } " )
7486
7587 except requests .exceptions .HTTPError as e :
76- raise SessionTokenRequestError (db or self .database , e .response .status_code , e .response .text ) from e
88+ raise SessionTokenRequestError (
89+ db or self .database , e .response .status_code , e .response .text
90+ ) from e
7791 except requests .exceptions .RequestException as e :
7892 logger .error (f"Request failed: { e } " )
7993 raise PUBMLSTError (f"Request failed: { e } " ) from e
8094 except Exception as e :
8195 logger .error (f"Unexpected error during request: { e } " )
8296 raise PUBMLSTError (f"An unexpected error occurred: { e } " ) from e
8397
84-
8598 def query_databases (self ):
8699 """Query available PubMLST databases."""
87100 url = f"{ BASE_API } /db"
88- return self ._make_request (RequestType .DB , HTTPMethod .GET , url , response_handler = ResponseHandler .JSON )
89-
101+ return self ._make_request (
102+ RequestType .DB , HTTPMethod .GET , url , response_handler = ResponseHandler .JSON
103+ )
90104
91105 def download_locus (self , db : str , locus : str , ** kwargs ):
92106 """Download locus sequence files."""
93107 base_url = f"{ BASE_API } /db/{ db } /loci/{ locus } /alleles_fasta"
94108 query_string = urlencode (kwargs )
95109 url = f"{ base_url } ?{ query_string } " if query_string else base_url
96- return self ._make_request (RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .TEXT )
97-
110+ return self ._make_request (
111+ RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .TEXT
112+ )
98113
99114 def download_profiles_csv (self , db : str , scheme_id : int ):
100115 """Download MLST profiles in CSV format."""
101116 if not scheme_id :
102117 raise ValueError ("Scheme ID is required to download profiles CSV." )
103118 url = f"{ BASE_API } /db/{ db } /schemes/{ scheme_id } /profiles_csv"
104- return self ._make_request (RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .TEXT )
105-
119+ return self ._make_request (
120+ RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .TEXT
121+ )
106122
107123 def retrieve_scheme_info (self , db : str , scheme_id : int ):
108124 """Retrieve information about a specific MLST scheme."""
109125 url = f"{ BASE_API } /db/{ db } /schemes/{ scheme_id } "
110- return self ._make_request (RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .JSON )
111-
126+ return self ._make_request (
127+ RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .JSON
128+ )
112129
113130 def list_schemes (self , db : str ):
114131 """List available MLST schemes for a specific database."""
115132 url = f"{ BASE_API } /db/{ db } /schemes"
116- return self ._make_request (RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .JSON )
133+ return self ._make_request (
134+ RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .JSON
135+ )
0 commit comments