Skip to content

Commit de75ee5

Browse files
committed
move auth acquisition, segment source assignment, improve error handling
1 parent 5309284 commit de75ee5

File tree

1 file changed

+57
-24
lines changed

1 file changed

+57
-24
lines changed

sde_collections/sinequa_api.py

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -52,31 +52,62 @@
5252

5353

5454
class Api:
55-
def __init__(self, server_name: str = None) -> None:
55+
def __init__(self, server_name: str = None, user: str = None, password: str = None, token: str = None) -> None:
5656
self.server_name = server_name
57+
if server_name not in server_configs:
58+
raise ValueError(f"Server name '{server_name}' is not in server_configs")
59+
5760
config = server_configs[server_name]
5861
self.app_name: str = config["app_name"]
5962
self.query_name: str = config["query_name"]
6063
self.base_url: str = config["base_url"]
61-
self.user = getattr(settings, f"{server_name}_USER".upper(), None)
62-
self.password = getattr(settings, f"{server_name}_PASSWORD".upper(), None)
63-
self.token = getattr(settings, f"{server_name}_TOKEN".upper(), None)
64+
self.dev_servers = ["xli", "lrm_dev", "lrm_qa"]
65+
66+
# Store provided values only
67+
self._provided_user = user
68+
self._provided_password = password
69+
self._provided_token = token
70+
71+
def _get_user(self) -> str | None:
72+
"""Retrieve the user, using the provided value or defaulting to Django settings."""
73+
return self._provided_user or getattr(settings, f"{self.server_name}_USER".upper(), None)
74+
75+
def _get_password(self) -> str | None:
76+
"""Retrieve the password, using the provided value or defaulting to Django settings."""
77+
return self._provided_password or getattr(settings, f"{self.server_name}_PASSWORD".upper(), None)
78+
79+
def _get_token(self) -> str | None:
80+
"""Retrieve the token, using the provided value or defaulting to Django settings."""
81+
return self._provided_token or getattr(settings, f"{self.server_name}_TOKEN".upper(), None)
82+
83+
def _get_source_name(self) -> str:
84+
"""by default, the source is /SDE/. However for the various dev servers, the source is tends to be /scrapers/"""
85+
return "scrapers" if self.server_name in self.dev_servers else "SDE"
6486

6587
def process_response(self, url: str, payload: dict[str, Any]) -> Any:
6688
response = requests.post(url, headers={}, json=payload, verify=False)
6789

68-
if response.status_code == requests.status_codes.codes.ok:
69-
meaningful_response = response.json()
90+
if response.status_code == requests.codes.ok:
91+
return response.json()
7092
else:
71-
raise Exception(response.text)
72-
73-
return meaningful_response
93+
response.raise_for_status()
7494

7595
def query(self, page: int, collection_config_folder: str = "") -> Any:
76-
if self.server_name:
77-
url = f"{self.base_url}/api/v1/search.query?Password={self.password}&User={self.user}"
96+
url = f"{self.base_url}/api/v1/search.query"
97+
if self.server_name in self.dev_servers:
98+
99+
user = self._get_user()
100+
password = self._get_password()
101+
102+
if not user or not password:
103+
raise ValueError(
104+
"User and password are required for the query endpoint on the following servers: {self.dev_servers}"
105+
)
106+
authentication = f"?Password={password}&User={user}"
107+
url = f"{url}{authentication}"
78108
else:
79109
url = f"{self.base_url}/api/v1/search.query"
110+
80111
payload = {
81112
"app": self.app_name,
82113
"query": {
@@ -89,22 +120,19 @@ def query(self, page: int, collection_config_folder: str = "") -> Any:
89120
}
90121

91122
if collection_config_folder:
92-
if self.server_name in ["xli", "lrm_dev", "lrm_qa"]:
93-
payload["query"]["advanced"]["collection"] = f"/scrapers/{collection_config_folder}/"
94-
else:
95-
payload["query"]["advanced"]["collection"] = f"/SDE/{collection_config_folder}/"
96-
97-
response = self.process_response(url, payload)
123+
source = self._get_source_name()
124+
payload["query"]["advanced"]["collection"] = f"/{source}/{collection_config_folder}/"
98125

99-
return response
126+
return self.process_response(url, payload)
100127

101128
def sql_query(self, sql: str) -> Any:
102129
"""Executes an SQL query on the configured server using token-based authentication."""
103-
if not self.token:
104-
raise ValueError("You must have a token to use the SQL endpoint")
130+
token = self._get_token()
131+
if not token:
132+
raise ValueError("A token is required to use the SQL endpoint")
105133

106134
url = f"{self.base_url}/api/v1/engine.sql"
107-
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.token}"}
135+
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {token}"}
108136
payload = json.dumps(
109137
{
110138
"method": "engine.sql",
@@ -116,14 +144,15 @@ def sql_query(self, sql: str) -> Any:
116144
"engines": "default",
117145
}
118146
)
147+
119148
try:
120149
response = requests.post(url, headers=headers, data=payload, timeout=10)
121150
response.raise_for_status()
122151
return response.json()
123152
except requests.exceptions.RequestException as e:
124-
raise Exception(f"API request failed: {str(e)}")
153+
raise RuntimeError(f"Api request to SQL endpoint failed: {str(e)}")
125154

126-
def get_full_texts(self, collection_config_folder: str) -> Any:
155+
def get_full_texts(self, collection_config_folder: str, source: str = None) -> Any:
127156
"""
128157
Retrieves the full texts, URLs, and titles for a specified collection.
129158
@@ -149,5 +178,9 @@ def get_full_texts(self, collection_config_folder: str) -> Any:
149178
}
150179
151180
"""
152-
sql = f"SELECT url1, text, title FROM sde_index WHERE collection = '/SDE/{collection_config_folder}/'"
181+
182+
if not source:
183+
source = self._get_source_name()
184+
185+
sql = f"SELECT url1, text, title FROM sde_index WHERE collection = '/{source}/{collection_config_folder}/'"
153186
return self.sql_query(sql)

0 commit comments

Comments
 (0)