|
1 | 1 | import asyncio |
| 2 | +import re |
2 | 3 | import ssl |
3 | 4 | from abc import ABC, abstractmethod |
4 | 5 | from json import dumps, loads |
5 | 6 | from pathlib import Path |
6 | 7 | from typing import Any, cast |
7 | 8 | from urllib.parse import urljoin |
8 | 9 |
|
9 | | -from aiohttp import BasicAuth, ClientSession, TCPConnector |
| 10 | +from aiohttp import BasicAuth, ClientError, ClientSession, ClientTimeout, TCPConnector |
10 | 11 | from opentelemetry.trace import get_tracer |
11 | 12 |
|
12 | 13 | from cbltest.api.error import CblSyncGatewayBadResponseError |
@@ -181,32 +182,6 @@ def __init__(self, key: str, id: str, revid: str | None, cv: str | None) -> None |
181 | 182 | self.__cv = cv |
182 | 183 |
|
183 | 184 |
|
184 | | -class DatabaseStatusResponse: |
185 | | - """ |
186 | | - A class representing a database status response from Sync Gateway |
187 | | - """ |
188 | | - |
189 | | - @property |
190 | | - def db_name(self) -> str: |
191 | | - """Gets the database name""" |
192 | | - return self.__db_name |
193 | | - |
194 | | - @property |
195 | | - def state(self) -> str: |
196 | | - """Gets the database state ('Online', 'Offline', etc.)""" |
197 | | - return self.__state |
198 | | - |
199 | | - @property |
200 | | - def update_seq(self) -> int: |
201 | | - """Gets the update sequence number""" |
202 | | - return self.__update_seq |
203 | | - |
204 | | - def __init__(self, response: dict): |
205 | | - self.__db_name = response.get("db_name", "") |
206 | | - self.__state = response.get("state", "Unknown") |
207 | | - self.__update_seq = response.get("update_seq", 0) |
208 | | - |
209 | | - |
210 | 185 | class AllDocumentsResponse: |
211 | 186 | """ |
212 | 187 | A class representing an all_docs response from Sync Gateway |
@@ -426,6 +401,32 @@ def parse(self, input: str) -> tuple[str, int]: |
426 | 401 | return input[0:first_lparen], int(input[first_lparen + 1 : first_semicol]) |
427 | 402 |
|
428 | 403 |
|
class DatabaseStatusResponse:
    """
    Read-only view over a database status payload returned by Sync Gateway
    """

    def __init__(self, response: dict):
        # Each field falls back to a neutral default when the key is absent
        # from the payload.
        db_name = response.get("db_name", "")
        state = response.get("state", "Unknown")
        update_seq = response.get("update_seq", 0)
        self.__db_name, self.__state, self.__update_seq = db_name, state, update_seq

    @property
    def db_name(self) -> str:
        """The name of the database ('' when not reported)"""
        return self.__db_name

    @property
    def state(self) -> str:
        """The database state, e.g. 'Online' or 'Offline' ('Unknown' when not reported)"""
        return self.__state

    @property
    def update_seq(self) -> int:
        """The update sequence number (0 when not reported)"""
        return self.__update_seq
| 428 | + |
| 429 | + |
429 | 430 | class _SyncGatewayBase: |
430 | 431 | """ |
431 | 432 | Base class for Sync Gateway clients containing common document and database operations. |
@@ -743,8 +744,7 @@ async def get_all_documents( |
743 | 744 | }, |
744 | 745 | ): |
745 | 746 | resp = await self._send_request( |
746 | | - "get", |
747 | | - f"/{db_name}.{scope}.{collection}/_all_docs", |
| 747 | + "get", f"/{db_name}.{scope}.{collection}/_all_docs" |
748 | 748 | ) |
749 | 749 |
|
750 | 750 | assert isinstance(resp, dict) |
@@ -775,8 +775,7 @@ async def get_changes( |
775 | 775 | ): |
776 | 776 | query_params = f"version_type={version_type}" |
777 | 777 | resp = await self._send_request( |
778 | | - "get", |
779 | | - f"/{db_name}.{scope}.{collection}/_changes?{query_params}", |
| 778 | + "get", f"/{db_name}.{scope}.{collection}/_changes?{query_params}" |
780 | 779 | ) |
781 | 780 |
|
782 | 781 | assert isinstance(resp, dict) |
@@ -1012,8 +1011,7 @@ async def get_document( |
1012 | 1011 | }, |
1013 | 1012 | ): |
1014 | 1013 | response = await self._send_request( |
1015 | | - "get", |
1016 | | - f"/{db_name}.{scope}.{collection}/{doc_id}", |
| 1014 | + "get", f"/{db_name}.{scope}.{collection}/{doc_id}" |
1017 | 1015 | ) |
1018 | 1016 | if not isinstance(response, dict): |
1019 | 1017 | raise ValueError( |
@@ -1098,6 +1096,238 @@ async def get_document_revision_public( |
1098 | 1096 | ) as session: |
1099 | 1097 | return await self._send_request("GET", path, params=params, session=session) |
1100 | 1098 |
|
async def _caddy_http_request(
    self,
    url: str,
    operation: str,
    timeout: int = 30,
    headers: dict[str, str] | None = None,
) -> tuple[int, bytes]:
    """
    Internal helper that issues a single HTTP GET against the Caddy server.

    A fresh ClientSession is opened for the request and closed on exit.

    :param url: Full Caddy URL to request
    :param operation: Description of operation (for error messages)
    :param timeout: Total request timeout in seconds
    :param headers: Optional HTTP headers to include in the request
    :return: Tuple of (status_code, content as bytes); the status is always 200
             on return because every other status raises
    :raises FileNotFoundError: If resource returns 404
    :raises asyncio.TimeoutError: If the request does not finish within ``timeout``
    :raises Exception: For other HTTP or network errors
    """
    try:
        async with ClientSession() as session:
            async with session.get(
                url, timeout=ClientTimeout(total=timeout), headers=headers
            ) as response:
                if response.status == 404:
                    raise FileNotFoundError(f"{operation} not found at {url}")
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(
                        f"{operation} failed: HTTP {response.status} - {error_text}"
                    )

                # Body is returned raw; callers decide how to decode it
                content = await response.read()
                return response.status, content

    except ClientError as e:
        raise Exception(f"Network error during {operation}: {e}") from e
    except asyncio.TimeoutError as e:
        # A total-timeout expiry is not a ClientError and carries no message.
        # Re-raise the same exception type (so existing handlers still match)
        # but say which operation and URL timed out.
        raise asyncio.TimeoutError(
            f"{operation} timed out after {timeout} seconds ({url})"
        ) from e
| 1136 | + |
async def fetch_log_file(
    self,
    log_type: str,
) -> str:
    """
    Fetches a log file from the remote Sync Gateway server via Caddy HTTP server

    The file requested is ``sg_<log_type>.log`` on port 20000 of this
    instance's hostname.

    :param log_type: Type of log file to fetch (e.g., 'debug', 'info', 'warn', 'error')
    :return: Content of the log file as a string; bytes that are not valid
             UTF-8 are replaced with U+FFFD rather than failing the fetch
    :raises FileNotFoundError: If the log file doesn't exist
    :raises Exception: For other HTTP errors
    """
    log_filename = f"sg_{log_type}.log"
    caddy_url = f"http://{self.hostname}:20000/{log_filename}"

    with self._tracer.start_as_current_span(
        "fetch_log_file",
        attributes={
            "cbl.log.type": log_type,
            "cbl.log.filename": log_filename,
            "cbl.caddy.url": caddy_url,
        },
    ):
        _, content = await self._caddy_http_request(
            caddy_url, f"Fetch {log_filename}", timeout=30
        )
        # A log that is still being written can be cut in the middle of a
        # multi-byte character; tolerate that instead of raising
        # UnicodeDecodeError.
        log_content = content.decode("utf-8", errors="replace")
        # Report the size of the raw payload: len() of the decoded text counts
        # characters, not bytes.
        cbl_info(f"Successfully fetched {log_filename} ({len(content)} bytes)")
        return log_content
| 1166 | + |
async def download_file_via_caddy(
    self,
    remote_filename: str,
    local_path: str,
) -> None:
    """
    Retrieves a file served by the remote Caddy HTTP server and stores it locally

    The whole file is read into memory before being written, and any missing
    parent directories of ``local_path`` are created.

    :param remote_filename: Name of the file on the remote server (e.g., 'sgcollectinfo-xxx-redacted.zip')
    :param local_path: Local path where the file should be saved
    :raises FileNotFoundError: If the file doesn't exist
    :raises Exception: For other HTTP errors
    """
    source_url = f"http://{self.hostname}:20000/{remote_filename}"
    span_attributes = {
        "cbl.remote.filename": remote_filename,
        "cbl.local.path": local_path,
        "cbl.caddy.url": source_url,
    }

    with self._tracer.start_as_current_span(
        "download_file_via_caddy", attributes=span_attributes
    ):
        _status, payload = await self._caddy_http_request(
            source_url, f"Download {remote_filename}", timeout=120
        )

        # Only touch the filesystem once the download has succeeded
        destination = Path(local_path)
        destination.parent.mkdir(parents=True, exist_ok=True)
        destination.write_bytes(payload)

        cbl_info(
            f"Successfully downloaded {remote_filename} to {local_path} ({len(payload)} bytes)"
        )
| 1202 | + |
async def list_files_via_caddy(
    self,
    pattern: str | None = None,
) -> list[str]:
    """
    Lists files available in the Caddy-served directory (requires 'browse' enabled in Caddyfile)

    The listing is requested as JSON from the server root on port 20000.
    Entries flagged as directories are skipped.

    :param pattern: Optional regex pattern to filter filenames (e.g., 'sgcollect_info.*redacted.zip');
                    matched with ``re.search``, so it may match anywhere in the name
    :return: List of filenames available in the directory
    :raises Exception: If directory browsing is not enabled, the request fails,
                       or the response is not a JSON array
    """
    caddy_url = f"http://{self.hostname}:20000/"

    with self._tracer.start_as_current_span(
        "list_files_via_caddy",
        attributes={
            "cbl.caddy.url": caddy_url,
            "cbl.pattern": pattern or "all",
        },
    ):
        try:
            _, content = await self._caddy_http_request(
                caddy_url,
                "List directory",
                timeout=30,
                headers={"Accept": "application/json"},
            )
        except FileNotFoundError as e:
            # Chain explicitly so the original 404 stays attached as the cause
            raise Exception(
                "Directory browsing endpoint not found. "
                "Ensure Caddy is configured with 'file_server browse'"
            ) from e

        # Parse JSON response from Caddy (UnicodeDecodeError and
        # JSONDecodeError are both ValueError subclasses)
        try:
            dir_listing = loads(content.decode("utf-8"))
        except ValueError as e:
            raise Exception(f"Failed to parse Caddy JSON response: {e}") from e

        # Anything other than an array is not a directory listing; fail loudly
        # instead of silently reporting "no files" or raising a bare TypeError.
        if not isinstance(dir_listing, list):
            raise Exception(
                "Unexpected Caddy browse response: expected a JSON array, "
                f"got {type(dir_listing).__name__}"
            )

        # Extract filenames from the JSON array
        files = [
            entry["name"]
            for entry in dir_listing
            if isinstance(entry, dict)
            and "name" in entry
            and not entry.get("is_dir", False)
        ]

        # Filter by pattern if provided
        if pattern:
            regex = re.compile(pattern)
            files = [f for f in files if regex.search(f)]

        cbl_info(
            f"Found {len(files)} files via Caddy browse (JSON)"
            + (f" (filtered by '{pattern}')" if pattern else "")
        )
        return files
| 1261 | + |
async def start_sgcollect(
    self,
    redact_level: str | None = None,
    redact_salt: str | None = None,
    output_dir: str | None = None,
) -> dict:
    """
    Kicks off an SGCollect run by POSTing to the /_sgcollect_info endpoint

    The request always sends ``upload: False``; each optional argument is
    included in the request body only when it is not None.

    :param redact_level: Redaction level ('none', 'partial', 'full')
    :param redact_salt: Custom salt for redaction hashing
    :param output_dir: Output directory on the remote server
    :return: Response dict with status
    """
    span_attributes = {"redact.level": redact_level or "none"}

    with self._tracer.start_as_current_span(
        "start_sgcollect", attributes=span_attributes
    ):
        optional_fields = {
            "redact_level": redact_level,
            "redact_salt": redact_salt,
            "output_dir": output_dir,
        }
        payload: dict[str, Any] = {"upload": False}
        payload.update(
            (key, value)
            for key, value in optional_fields.items()
            if value is not None
        )

        result = await self._send_request(
            "post", "/_sgcollect_info", JSONDictionary(payload)
        )
        assert isinstance(result, dict)
        return cast(dict, result)
| 1297 | + |
async def get_sgcollect_status(self) -> dict:
    """
    Queries the /_sgcollect_info endpoint for the state of the SGCollect operation

    :return: Response dict with status ('stopped' or 'running')
    """
    span = self._tracer.start_as_current_span("get_sgcollect_status")
    with span:
        status = await self._send_request("get", "/_sgcollect_info")
        assert isinstance(status, dict)
        return cast(dict, status)
| 1308 | + |
async def wait_for_sgcollect_to_complete(
    self, max_attempts: int = 60, wait_time: int = 2
) -> None:
    """
    Waits for SGCollect to complete, polling until the status is 'stopped' or 'completed'.
    By default polls up to 60 times, waiting 2 seconds after each unsuccessful poll.

    :param max_attempts: Maximum number of status polls before giving up
    :param wait_time: Seconds to wait after each poll that did not report completion
    :raises Exception: If no poll reported 'stopped' or 'completed' within
                       ``max_attempts`` polls (including when ``max_attempts`` <= 0)
    """
    # Pre-bind so the failure message below is well-defined even when
    # max_attempts <= 0 and the loop body never runs.
    status_resp: dict = {}

    for _ in range(max_attempts):
        status_resp = await self.get_sgcollect_status()
        if status_resp.get("status") in ("stopped", "completed"):
            return
        await asyncio.sleep(wait_time)

    raise Exception(
        f"SGCollect did not complete after {max_attempts * wait_time} seconds.\n"
        f"Status: {status_resp.get('status')}.\n"
        f"Error: {status_resp.get('error')}"
    )
| 1330 | + |
1101 | 1331 |
|
1102 | 1332 | class SyncGateway(_SyncGatewayBase): |
1103 | 1333 | """ |
@@ -1304,15 +1534,6 @@ class SyncGatewayUserClient(_SyncGatewayBase): |
1304 | 1534 |
|
1305 | 1535 | This class inherits common operations from _SyncGatewayBase and does NOT |
1306 | 1536 | include admin methods (user management, roles, etc.). |
1307 | | -
|
1308 | | - Use SyncGateway.create_user_client() to create instances with proper user credentials |
1309 | | - and channel access. |
1310 | | -
|
1311 | | - Example: |
1312 | | - admin_sg = SyncGateway("localhost", "admin", "password") |
1313 | | - user_sg = await admin_sg.create_user_client("db", "alice", "pass", ["channel1"]) |
1314 | | - # user_sg automatically uses port 4984 for all API calls |
1315 | | - docs = await user_sg.get_all_documents("db") |
1316 | 1537 | """ |
1317 | 1538 |
|
1318 | 1539 | def __init__( |
|
0 commit comments