|
| 1 | +import logging |
1 | 2 | import os |
2 | 3 | import time |
3 | | -from dataclasses import dataclass, field |
| 4 | +from dataclasses import asdict, dataclass, field |
| 5 | +from json import JSONDecodeError |
| 6 | +from json import dumps as json_dumps |
| 7 | +from json import loads as json_loads |
| 8 | +from os import makedirs, path |
4 | 9 | from typing import ( |
5 | 10 | Any, |
6 | 11 | Callable, |
|
12 | 17 | TypeVar, |
13 | 18 | ) |
14 | 19 |
|
| 20 | +from appdirs import user_data_dir |
| 21 | + |
| 22 | +from firebolt.utils.file_operations import ( |
| 23 | + FernetEncrypter, |
| 24 | + generate_encrypted_file_name, |
| 25 | + generate_salt, |
| 26 | +) |
| 27 | + |
15 | 28 | T = TypeVar("T") |
16 | 29 |
|
17 | 30 | # Cache expiry configuration |
18 | 31 | CACHE_EXPIRY_SECONDS = 3600 # 1 hour |
| 32 | +APPNAME = "firebolt" |
| 33 | + |
| 34 | +logger = logging.getLogger(__name__) |
19 | 35 |
|
20 | 36 |
|
21 | 37 | class ReprCacheable(Protocol): |
@@ -47,6 +63,22 @@ class ConnectionInfo: |
47 | 63 | system_engine: Optional[EngineInfo] = None |
48 | 64 | databases: Dict[str, DatabaseInfo] = field(default_factory=dict) |
49 | 65 | engines: Dict[str, EngineInfo] = field(default_factory=dict) |
| 66 | + token: Optional[str] = None |
| 67 | + |
def __post_init__(self) -> None:
    """
    Post-initialization processing to convert dicts to dataclasses.

    Plain dicts in ``system_engine``, ``databases`` and ``engines`` are
    turned into ``EngineInfo`` / ``DatabaseInfo`` instances by passing
    the dict as keyword arguments. Values that are not dicts are kept
    exactly as they were given.
    """
    # A falsy value (None or an empty dict) is deliberately left untouched.
    if self.system_engine and isinstance(self.system_engine, dict):
        self.system_engine = EngineInfo(**self.system_engine)
    # Convert only the dict values; filtering on isinstance() here would
    # silently drop every entry that is already a dataclass instance.
    self.databases = {
        k: DatabaseInfo(**v) if isinstance(v, dict) else v
        for k, v in self.databases.items()
    }
    self.engines = {
        k: EngineInfo(**v) if isinstance(v, dict) else v
        for k, v in self.engines.items()
    }
50 | 82 |
|
51 | 83 |
|
52 | 84 | def noop_if_disabled(func: Callable) -> Callable: |
@@ -150,4 +182,118 @@ def __hash__(self) -> int: |
150 | 182 | return hash(self.key) |
151 | 183 |
|
152 | 184 |
|
153 | | -_firebolt_cache = UtilCache[ConnectionInfo](cache_name="connection_info") |
class FileBasedCache(UtilCache[ConnectionInfo]):
    """
    File-based cache that persists to disk with encryption.

    Extends UtilCache to provide persistent storage using encrypted files.
    Every entry is kept in the in-memory cache and is also written to a file
    inside the per-user data directory. Disk failures are never fatal: they
    are logged at debug level and the in-memory cache keeps working.
    """

    def __init__(self, cache_name: str = ""):
        """
        Create the cache and make sure the on-disk directory exists.

        Args:
            cache_name: Name forwarded to the UtilCache constructor.
        """
        super().__init__(cache_name)
        self._data_dir = user_data_dir(appname=APPNAME)  # TODO: change to new dir
        try:
            makedirs(self._data_dir, exist_ok=True)
        except OSError as e:
            # An instance is created at module import time, so an unwritable
            # location must not break the import. Later reads find no file
            # and later writes fail and are logged.
            logger.debug(
                "Failed to create cache directory %s error: %s", self._data_dir, e
            )

    def _get_file_path(self, key: SecureCacheKey) -> str:
        """Get the file path for a cache key."""
        cache_key = self.create_key(key)
        encrypted_filename = generate_encrypted_file_name(cache_key, key.encryption_key)
        return path.join(self._data_dir, encrypted_filename)

    def _read_data_json(self, file_path: str, encrypter: FernetEncrypter) -> dict:
        """
        Read and decrypt JSON data from file.

        Args:
            file_path: Location of the encrypted cache file.
            encrypter: Encrypter used to decrypt the file content.

        Returns:
            The decoded JSON object, or an empty dict when the file is
            missing, unreadable, cannot be decrypted, or does not contain
            a JSON object.
        """
        if not path.exists(file_path):
            return {}

        try:
            with open(file_path, "r") as f:
                encrypted_data = f.read()

            decrypted_data = encrypter.decrypt(encrypted_data)
            if decrypted_data is None:
                logger.debug("Decryption failed for %s", file_path)
                return {}

            data = json_loads(decrypted_data) if decrypted_data else {}
        except (JSONDecodeError, UnicodeDecodeError, OSError) as e:
            # UnicodeDecodeError covers a file whose bytes are not valid text.
            logger.debug(
                "Failed to read or decode data from %s error: %s", file_path, e
            )
            return {}

        if not isinstance(data, dict):
            # Valid JSON that is not an object (list, number, ...) is useless
            # to callers, which expand the result as keyword arguments.
            logger.debug("Unexpected JSON content in %s", file_path)
            return {}
        return data

    def _write_data_json(
        self, file_path: str, data: dict, encrypter: FernetEncrypter
    ) -> None:
        """
        Encrypt and write JSON data to file.

        Args:
            file_path: Destination of the encrypted cache file.
            data: JSON-serializable mapping to store.
            encrypter: Encrypter used to encrypt the serialized data.
        """
        try:
            json_str = json_dumps(data)
            logger.debug("Writing data to %s", file_path)
            encrypted_data = encrypter.encrypt(json_str)
            with open(file_path, "w") as f:
                f.write(encrypted_data)
        except OSError as e:
            # Silently proceed if we can't write to disk
            logger.debug("Failed to write data to %s error: %s", file_path, e)

    def get(self, key: SecureCacheKey) -> Optional[ConnectionInfo]:
        """
        Get value from cache, checking both memory and disk.

        Args:
            key: Cache key; its ``encryption_key`` is used for the file name
                and for decryption.

        Returns:
            The cached ConnectionInfo, or None when the cache is disabled,
            nothing usable is stored, or the stored data cannot be turned
            into a ConnectionInfo.
        """
        if self.disabled:
            return None

        # First try memory cache
        memory_result = super().get(key)
        if memory_result is not None:
            logger.debug("Cache hit in memory")
            return memory_result

        # If not in memory, try to load from disk
        file_path = self._get_file_path(key)
        # NOTE(review): set() also calls generate_salt() on its own. If that
        # helper is not deterministic, data written by set() can never be
        # decrypted here - confirm against firebolt.utils.file_operations.
        encrypter = FernetEncrypter(generate_salt(), key.encryption_key)
        raw_data = self._read_data_json(file_path, encrypter)
        if not raw_data:
            return None

        try:
            data = ConnectionInfo(**raw_data)
        except TypeError as e:
            # The file holds keys ConnectionInfo does not accept (for example
            # written with a different schema); treat it as a cache miss.
            logger.debug(
                "Cached data in %s is not a valid ConnectionInfo error: %s",
                file_path,
                e,
            )
            return None
        logger.debug("Cache hit on disk")

        # Add to memory cache and return
        super().set(key, data)
        return data

    def set(self, key: SecureCacheKey, value: ConnectionInfo) -> None:
        """
        Set value in both memory and disk cache.

        Args:
            key: Cache key; its ``encryption_key`` is used for the file name
                and for encryption.
            value: Connection info to store.
        """
        if self.disabled:
            return

        logger.debug("Setting value in cache")
        # First set in memory
        super().set(key, value)

        file_path = self._get_file_path(key)
        encrypter = FernetEncrypter(generate_salt(), key.encryption_key)
        data = asdict(value)

        self._write_data_json(file_path, data, encrypter)

    def delete(self, key: SecureCacheKey) -> None:
        """
        Delete value from both memory and disk cache.

        Args:
            key: Cache key whose entry and backing file should be removed.
        """
        if self.disabled:
            return

        # Delete from memory
        super().delete(key)

        # Delete from disk. Removing directly instead of checking for
        # existence first avoids a race with concurrent deletions.
        file_path = self._get_file_path(key)
        try:
            os.remove(file_path)
        except FileNotFoundError:
            # Nothing stored on disk for this key.
            pass
        except OSError:
            logger.debug("Failed to delete file %s", file_path)
            # Silently proceed if we can't delete the file

    def clear(self) -> None:
        """Clear the in-memory cache only; files on disk are left in place."""
        # Clear memory only, as deleting every file is not safe
        logger.debug("Clearing memory cache")
        super().clear()
| 297 | + |
| 298 | + |
| 299 | +_firebolt_cache = FileBasedCache(cache_name="connection_info") |
0 commit comments