1- import requests
21from urllib .parse import urlencode
2+
3+ import requests
4+ from rauth import OAuth1Session
5+
6+ from microSALT import logger
7+ from microSALT .utils .pubmlst .authentication import load_session_credentials
8+ from microSALT .utils .pubmlst .constants import HTTPMethod , RequestType , ResponseHandler
9+ from microSALT .utils .pubmlst .exceptions import PUBMLSTError , SessionTokenRequestError
310from microSALT .utils .pubmlst .helpers import (
411 BASE_API ,
512 generate_oauth_header ,
613 load_auth_credentials ,
7- parse_pubmlst_url
14+ parse_pubmlst_url ,
815)
9- from microSALT .utils .pubmlst .constants import RequestType , HTTPMethod , ResponseHandler
10- from microSALT .utils .pubmlst .exceptions import PUBMLSTError , SessionTokenRequestError
11- from microSALT .utils .pubmlst .authentication import load_session_credentials
12- from microSALT import logger
16+
1317
1418class PubMLSTClient :
1519 """Client for interacting with the PubMLST authenticated API."""
1620
1721 def __init__ (self ):
1822 """Initialize the PubMLST client."""
1923 try :
20- self .consumer_key , self .consumer_secret , self .access_token , self .access_secret = load_auth_credentials ()
24+ self .consumer_key , self .consumer_secret , self .access_token , self .access_secret = (
25+ load_auth_credentials ()
26+ )
2127 self .database = "pubmlst_test_seqdef"
2228 self .session_token , self .session_secret = load_session_credentials (self .database )
2329 except PUBMLSTError as e :
2430 logger .error (f"Failed to initialize PubMLST client: { e } " )
2531 raise
2632
27-
2833 @staticmethod
2934 def parse_pubmlst_url (url : str ):
3035 """
3136 Wrapper for the parse_pubmlst_url function.
3237 """
3338 return parse_pubmlst_url (url )
3439
35-
36- def _make_request (self , request_type : RequestType , method : HTTPMethod , url : str , db : str = None , response_handler : ResponseHandler = ResponseHandler .JSON ):
37- """ Handle API requests."""
40+ def _make_request (
41+ self ,
42+ request_type : RequestType ,
43+ method : HTTPMethod ,
44+ url : str ,
45+ db : str = None ,
46+ response_handler : ResponseHandler = ResponseHandler .JSON ,
47+ ):
48+ """Handle API requests."""
3849 try :
39- if db :
40- session_token , session_secret = load_session_credentials (db )
41- else :
42- session_token , session_secret = self .session_token , self .session_secret
43-
4450 if request_type == RequestType .AUTH :
45- headers = {
46- "Authorization" : generate_oauth_header ( url , self .consumer_key , self . consumer_secret , self . access_token , self . access_secret )
47- }
51+ access_token = self . access_token
52+ access_secret = self .access_secret
53+ log_database = "authentication"
4854 elif request_type == RequestType .DB :
49- headers = {
50- "Authorization" : generate_oauth_header (url , self .consumer_key , self .consumer_secret , session_token , session_secret )
51- }
55+ access_token , access_secret = load_session_credentials (db or self .database )
56+ log_database = db or self .database
5257 else :
5358 raise ValueError (f"Unsupported request type: { request_type } " )
5459
55- if method == HTTPMethod .GET :
56- response = requests .get (url , headers = headers )
57- elif method == HTTPMethod .POST :
58- response = requests .post (url , headers = headers )
59- elif method == HTTPMethod .PUT :
60- response = requests .put (url , headers = headers )
61- else :
62- raise ValueError (f"Unsupported HTTP method: { method } " )
63-
60+ logger .info (f"Making request to { url } for database { log_database } " )
61+ logger .info (f"Using session token: { access_token [:6 ]} ****" )
62+
63+ session = OAuth1Session (
64+ consumer_key = self .consumer_key ,
65+ consumer_secret = self .consumer_secret ,
66+ access_token = access_token ,
67+ access_token_secret = access_secret ,
68+ )
69+
70+ headers = {"User-Agent" : "BIGSdb API Client" }
71+ response = session .request (method .value , url , headers = headers )
72+
73+ if response .status_code == 401 :
74+ logger .error (f"401 Unauthorized: { response .text } " )
75+
6476 response .raise_for_status ()
65-
77+
78+ # Process the response based on the requested handler type
6679 if response_handler == ResponseHandler .CONTENT :
6780 return response .content
6881 elif response_handler == ResponseHandler .TEXT :
@@ -73,44 +86,51 @@ def _make_request(self, request_type: RequestType, method: HTTPMethod, url: str,
7386 raise ValueError (f"Unsupported response handler: { response_handler } " )
7487
7588 except requests .exceptions .HTTPError as e :
76- raise SessionTokenRequestError (db or self .database , e .response .status_code , e .response .text ) from e
89+ raise SessionTokenRequestError (
90+ db or self .database , e .response .status_code , e .response .text
91+ ) from e
7792 except requests .exceptions .RequestException as e :
7893 logger .error (f"Request failed: { e } " )
7994 raise PUBMLSTError (f"Request failed: { e } " ) from e
8095 except Exception as e :
8196 logger .error (f"Unexpected error during request: { e } " )
8297 raise PUBMLSTError (f"An unexpected error occurred: { e } " ) from e
8398
84-
8599 def query_databases (self ):
86100 """Query available PubMLST databases."""
87101 url = f"{ BASE_API } /db"
88- return self ._make_request (RequestType .DB , HTTPMethod .GET , url , response_handler = ResponseHandler .JSON )
89-
102+ return self ._make_request (
103+ RequestType .DB , HTTPMethod .GET , url , response_handler = ResponseHandler .JSON
104+ )
90105
91106 def download_locus (self , db : str , locus : str , ** kwargs ):
92107 """Download locus sequence files."""
93108 base_url = f"{ BASE_API } /db/{ db } /loci/{ locus } /alleles_fasta"
94109 query_string = urlencode (kwargs )
95110 url = f"{ base_url } ?{ query_string } " if query_string else base_url
96- return self ._make_request (RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .TEXT )
97-
111+ return self ._make_request (
112+ RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .TEXT
113+ )
98114
99115 def download_profiles_csv (self , db : str , scheme_id : int ):
100116 """Download MLST profiles in CSV format."""
101117 if not scheme_id :
102118 raise ValueError ("Scheme ID is required to download profiles CSV." )
103119 url = f"{ BASE_API } /db/{ db } /schemes/{ scheme_id } /profiles_csv"
104- return self ._make_request (RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .TEXT )
105-
120+ return self ._make_request (
121+ RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .TEXT
122+ )
106123
107124 def retrieve_scheme_info (self , db : str , scheme_id : int ):
108125 """Retrieve information about a specific MLST scheme."""
109126 url = f"{ BASE_API } /db/{ db } /schemes/{ scheme_id } "
110- return self ._make_request (RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .JSON )
111-
127+ return self ._make_request (
128+ RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .JSON
129+ )
112130
113131 def list_schemes (self , db : str ):
114132 """List available MLST schemes for a specific database."""
115133 url = f"{ BASE_API } /db/{ db } /schemes"
116- return self ._make_request (RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .JSON )
134+ return self ._make_request (
135+ RequestType .DB , HTTPMethod .GET , url , db = db , response_handler = ResponseHandler .JSON
136+ )
0 commit comments