diff --git a/cli.py b/cli.py index 6603332..ee25e71 100644 --- a/cli.py +++ b/cli.py @@ -9,19 +9,30 @@ from commands.analyze_kinds import analyze_kinds, print_summary_table from commands.analyze_entity_fields import analyze_field_contributions, print_field_summary from commands.cleanup_expired import cleanup_expired +from commands.drive_sync import push_to_drive, pull_from_drive -app = typer.Typer(help="Utilities for analyzing and managing local Datastore/Firestore (Datastore mode)", no_args_is_help=True) +app = typer.Typer( + help="Utilities for analyzing and managing local Datastore/Firestore (Datastore mode)", + no_args_is_help=True, +) # Aliases with flags only — no defaults here ConfigOpt = Annotated[Optional[str], typer.Option("--config", help="Path to config.yaml")] ProjectOpt = Annotated[Optional[str], typer.Option("--project", help="GCP/Emulator project id")] -EmulatorHostOpt = Annotated[Optional[str], typer.Option("--emulator-host", help="Emulator host, e.g. localhost:8010")] +EmulatorHostOpt = Annotated[ + Optional[str], typer.Option("--emulator-host", help="Emulator host, e.g. localhost:8010") +] LogLevelOpt = Annotated[Optional[str], typer.Option("--log-level", help="Logging level")] KindsOpt = Annotated[ Optional[List[str]], - typer.Option("--kind", "-k", help="Kinds to process (omit or empty to process all in each namespace)") + typer.Option( + "--kind", "-k", help="Kinds to process (omit or empty to process all in each namespace)" + ), +] +SingleKindOpt = Annotated[ + Optional[str], typer.Option("--kind", "-k", help="Kind to analyze (falls back to config.kind)") ] -SingleKindOpt = Annotated[Optional[str], typer.Option("--kind", "-k", help="Kind to analyze (falls back to config.kind)")] + def _load_cfg( config_path: Optional[str], @@ -38,6 +49,7 @@ def _load_cfg( overrides["log_level"] = log_level return load_config(config_path, overrides) + @app.command("analyze-kinds") def cmd_analyze_kinds( config: ConfigOpt = None, @@ -64,17 +76,31 @@ def cmd_analyze_kinds( else: print_summary_table(rows) + @app.command("analyze-fields") def cmd_analyze_fields( kind: SingleKindOpt = None, - namespace: Annotated[Optional[str], typer.Option("--namespace", "-n", help="Namespace to query (omit to use all)")] = None, - group_by: Annotated[Optional[str], typer.Option("--group-by", help="Group results by this field value (falls back to config.group_by_field)")] = None, - only_field: Annotated[Optional[List[str]], typer.Option("--only-field", help="Only consider these fields")] = None, + namespace: Annotated[ + Optional[str], + typer.Option("--namespace", "-n", help="Namespace to query (omit to use all)"), + ] = None, + group_by: Annotated[ + Optional[str], + typer.Option( + "--group-by", + help="Group results by this field value (falls back to config.group_by_field)", + ), + ] = None, + only_field: Annotated[ + Optional[List[str]], typer.Option("--only-field", help="Only consider these fields") + ] = None, config: ConfigOpt = None, project: ProjectOpt = None, emulator_host: EmulatorHostOpt = None, log_level: LogLevelOpt = None, - output_json: Annotated[Optional[str], typer.Option("--output-json", help="Write raw JSON results to file")] = None, + output_json: Annotated[ + Optional[str], typer.Option("--output-json", help="Write raw JSON results to file") + ] = None, ): cfg = _load_cfg(config, project, emulator_host, log_level) @@ -100,6 +126,7 @@ def cmd_analyze_fields( else: print_field_summary(result) + @app.command("cleanup") def cmd_cleanup( config: ConfigOpt = None, @@ -107,10 +134,24 @@ 
def cmd_cleanup( emulator_host: EmulatorHostOpt = None, log_level: LogLevelOpt = None, kind: KindsOpt = None, - ttl_field: Annotated[Optional[str], typer.Option("--ttl-field", help="TTL field name (falls back to config.ttl_field)")] = None, - delete_missing_ttl: Annotated[Optional[bool], typer.Option("--delete-missing-ttl", help="Delete when TTL field is missing (falls back to config.delete_missing_ttl)")] = None, - batch_size: Annotated[Optional[int], typer.Option("--batch-size", help="Delete batch size (falls back to config.batch_size)")] = None, - dry_run: Annotated[bool, typer.Option("--dry-run", help="Only report counts; do not delete")] = False, + ttl_field: Annotated[ + Optional[str], + typer.Option("--ttl-field", help="TTL field name (falls back to config.ttl_field)"), + ] = None, + delete_missing_ttl: Annotated[ + Optional[bool], + typer.Option( + "--delete-missing-ttl", + help="Delete when TTL field is missing (falls back to config.delete_missing_ttl)", + ), + ] = None, + batch_size: Annotated[ + Optional[int], + typer.Option("--batch-size", help="Delete batch size (falls back to config.batch_size)"), + ] = None, + dry_run: Annotated[ + bool, typer.Option("--dry-run", help="Only report counts; do not delete") + ] = False, ): cfg = _load_cfg(config, project, emulator_host, log_level) @@ -127,6 +168,81 @@ def cmd_cleanup( deleted_sum = sum(totals.values()) typer.echo(f"Total entities {'to delete' if dry_run else 'deleted'}: {deleted_sum}") + +db_app = typer.Typer(help="Database backup management commands", no_args_is_help=True) + + +@db_app.command("push") +def db_push( + version: Annotated[ + Optional[str], typer.Argument(help="Version name (defaults to today's date YYYY-mm-DD)") + ] = None, + overwrite: Annotated[ + bool, typer.Option("-o", "--overwrite", help="Overwrite existing file with same name") + ] = False, + local_db: Annotated[ + Optional[str], + typer.Option( + "--local-db", + help="Optional helper script path (e.g. tools/dev-env/local-db). This script may stash/restore; the actual data file comes from config.local_db_path." + ), + ] = None, + dry_run: Annotated[ + bool, + typer.Option( + "--dry-run", help="Do not upload to Drive, just show what would be uploaded" + ), + ] = False, + config: ConfigOpt = None, + log_level: LogLevelOpt = None, +): + cfg = _load_cfg(config, None, None, log_level) + push_to_drive(cfg, version, overwrite, local_db, dry_run=dry_run) + + +@db_app.command("pull") +def db_pull( + version: Annotated[ + Optional[str], typer.Argument(help="Version name (omit to download latest)") + ] = None, + local_db: Annotated[ + Optional[str], + typer.Option( + "--local-db", + help="Optional helper script path (e.g. tools/dev-env/local-db). 
This script may stash/restore; the actual data file comes from config.local_db_path.", + ), + ] = None, + overwrite: Annotated[ + bool, + typer.Option( + "--overwrite/--no-overwrite", + help="Whether to overwrite the local data file when restoring from Drive (default: overwrite)", + show_default=True, + ), + ] = True, + config: ConfigOpt = None, + log_level: LogLevelOpt = None, +): + cfg = _load_cfg(config, None, None, log_level) + pull_from_drive(cfg, version, local_db, overwrite=overwrite) + + +@db_app.command("list") +def db_list( + config: ConfigOpt = None, + log_level: LogLevelOpt = None, +): + cfg = _load_cfg(config, None, None, log_level) + from commands.drive_sync import list_backups + + backups = list_backups(cfg) + for b in backups: + typer.echo(b) + + +app.add_typer(db_app, name="db") + + if __name__ == "__main__": import sys diff --git a/commands/__init__.py b/commands/__init__.py index a225483..5744e4f 100644 --- a/commands/__init__.py +++ b/commands/__init__.py @@ -2,20 +2,23 @@ from .analyze_kinds import analyze_kinds, get_kind_stats, estimate_entity_count_and_size from .analyze_entity_fields import analyze_field_contributions, print_field_summary from .cleanup_expired import cleanup_expired +from .drive_sync import push_to_drive, pull_from_drive from . import config as config __all__ = [ - "AppConfig", - "load_config", - "build_client", - "list_namespaces", - "list_kinds", - "format_size", - "analyze_kinds", - "get_kind_stats", - "estimate_entity_count_and_size", - "analyze_field_contributions", - "print_field_summary", - "cleanup_expired", - "config", + "AppConfig", + "load_config", + "build_client", + "list_namespaces", + "list_kinds", + "format_size", + "analyze_kinds", + "get_kind_stats", + "estimate_entity_count_and_size", + "analyze_field_contributions", + "print_field_summary", + "cleanup_expired", + "push_to_drive", + "pull_from_drive", + "config", ] diff --git a/commands/config.py b/commands/config.py index e7d4ba5..99f3c71 100644 --- a/commands/config.py +++ b/commands/config.py @@ -37,6 +37,11 @@ class AppConfig: # Logging log_level: str = "INFO" + # Drive sync settings + local_db_path: Optional[str] = None + # Google Drive folder name where backups are stored + gdrive_directory: str = "datastore" + def _as_list(value: Optional[Iterable[str]]) -> List[str]: if value is None: @@ -90,6 +95,9 @@ def load_config(path: Optional[str] = None, overrides: Optional[Dict] = None) -> config.log_level = str(merged.get("log_level", config.log_level)).upper() + config.local_db_path = merged.get("local_db_path", config.local_db_path) + config.gdrive_directory = merged.get("gdrive_directory", config.gdrive_directory) + _configure_logging(config.log_level) return config @@ -155,4 +163,3 @@ def format_size(bytes_size: int) -> str: return f"{size:.2f} {unit}" size /= 1024 return f"{size:.2f} PB" - diff --git a/commands/drive_sync.py b/commands/drive_sync.py new file mode 100644 index 0000000..42172e0 --- /dev/null +++ b/commands/drive_sync.py @@ -0,0 +1,249 @@ +from __future__ import annotations + +import logging +import os +import subprocess +from datetime import datetime +from typing import Optional, List +import shutil + +from pydrive2.auth import GoogleAuth +from pydrive2.drive import GoogleDrive + +try: + # Optional imports for ADC/google-api fallback + import google.auth + from googleapiclient.discovery import build + from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload +except Exception: + google = None + +from .config import AppConfig + +logger = 
logging.getLogger(__name__) + + +def _authenticate_drive() -> GoogleDrive: + gauth = GoogleAuth() + # Interactive flow: open local webserver and prompt user to authenticate in browser + gauth.LocalWebserverAuth() + return GoogleDrive(gauth) + + +def _get_local_db_path(config: AppConfig, local_db_override: Optional[str]) -> str: + local_db = local_db_override or config.local_db_path + if not local_db: + raise ValueError("local-db path must be provided via --local-db or config.local_db_path") + if not os.path.exists(local_db): + raise FileNotFoundError(f"local-db binary not found at: {local_db}") + return local_db + + +def _get_or_create_gdrive_folder(drive: GoogleDrive, folder_name: str) -> str: + file_list = drive.ListFile( + { + "q": f"title='{folder_name}' and mimeType='application/vnd.google-apps.folder' and trashed=false" + } + ).GetList() + if file_list: + return file_list[0]["id"] + folder = drive.CreateFile({"title": folder_name, "mimeType": "application/vnd.google-apps.folder"}) + folder.Upload() + logger.info(f"Created /{folder_name} folder in Google Drive") + return folder["id"] + + +def _get_adc_drive_service(): + """Return a googleapiclient Drive service when ADC is available and has Drive scope, otherwise None. + + We avoid using ADC unless the obtained credentials explicitly include the Drive scope. This + prevents accidentally selecting ADC in test environments or runtime environments where the + default credentials don't have Drive permissions (which would cause 403 errors later). + """ + try: + if google is None: + return None + DRIVE_SCOPE = "https://www.googleapis.com/auth/drive" + # Try to get ADC without forcing scopes first; some environments may provide scoped creds + creds, _ = google.auth.default() + + scopes = getattr(creds, "scopes", None) + # If scopes are not present or don't include the Drive scope, try requesting the Drive scope + if not scopes or DRIVE_SCOPE not in scopes: + try: + creds, _ = google.auth.default(scopes=[DRIVE_SCOPE]) + scopes = getattr(creds, "scopes", None) + except Exception: + return None + + if not scopes or DRIVE_SCOPE not in scopes: + return None + + service = build("drive", "v3", credentials=creds, cache_discovery=False) + return service + except Exception: + return None + + +def list_backups(config: AppConfig) -> List[str]: + drive = _authenticate_drive() + folder_id = _get_or_create_gdrive_folder(drive, config.gdrive_directory) + files = drive.ListFile({"q": f"'{folder_id}' in parents and trashed=false"}).GetList() + return [f["title"] for f in files] + + +def _run_local_db_command(local_db_script: str, args: list[str]) -> str: + # Run helper script (if provided) and return stdout. The script is optional; callers may ignore the output. 
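+    # Example (illustrative): _run_local_db_command("tools/dev-env/local-db", ["stash", "2024-01-01"])
+    # runs `tools/dev-env/local-db stash 2024-01-01` and returns its stdout; push_to_drive
+    # invokes the script this way before uploading.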
+ if not os.path.exists(local_db_script): + raise FileNotFoundError(f"local-db script not found at: {local_db_script}") + if not os.access(local_db_script, os.X_OK): + raise PermissionError(f"local-db script is not executable: {local_db_script}") + + cmd = [local_db_script] + args + logger.info(f"Running: {' '.join(cmd)}") + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Command failed: {result.stderr}") + if result.stdout: + logger.info(result.stdout) + return result.stdout + + +def push_to_drive( + config: AppConfig, + version: Optional[str], + overwrite: bool, + local_db_script: Optional[str], + dry_run: bool = False, +) -> None: + # local_db_script: optional helper script that produces stashes; local_db_path points to the data binary + data_path = config.local_db_path + if local_db_script: + # If script provided via CLI, use that to produce a stash; capture any output but do not rely on it + _run_local_db_command(local_db_script, ["stash", version or datetime.now().strftime("%Y-%m-%d")]) + + if not data_path or not os.path.exists(data_path): + raise FileNotFoundError("local_db_path must point to the datastore data binary to upload") + + if not version: + version = datetime.now().strftime("%Y-%m-%d") + + backup_file = f"local-db-{version}.bin" + + if dry_run: + logger.info(f"DRY RUN: would upload {data_path} to Google Drive as {backup_file} in /{config.gdrive_directory}") + return + # Prefer Application Default Credentials (gcloud auth) when available + adc_service = _get_adc_drive_service() + if adc_service: + # Ensure folder exists or create it + # Search for folder + q = f"mimeType='application/vnd.google-apps.folder' and name='{config.gdrive_directory}' and trashed=false" + res = adc_service.files().list(q=q, fields="files(id,name)").execute() + files = res.get("files", []) + if files: + folder_id = files[0]["id"] + else: + file_metadata = {"name": config.gdrive_directory, "mimeType": "application/vnd.google-apps.folder"} + created = adc_service.files().create(body=file_metadata, fields="id").execute() + folder_id = created["id"] + + # Check if a backup with the same name exists + q2 = f"name='{backup_file}' and '{folder_id}' in parents and trashed=false" + res2 = adc_service.files().list(q=q2, fields="files(id,name)").execute() + existing_files = res2.get("files", []) + if existing_files and not overwrite: + raise FileExistsError(f"File {backup_file} already exists in /{config.gdrive_directory}. Use -o to overwrite.") + + media = MediaFileUpload(data_path, resumable=True) + if existing_files: + file_id = existing_files[0]["id"] + adc_service.files().update(fileId=file_id, media_body=media).execute() + else: + file_metadata = {"name": backup_file, "parents": [folder_id]} + adc_service.files().create(body=file_metadata, media_body=media, fields="id").execute() + logger.info(f"Successfully uploaded {backup_file} to Google Drive /{config.gdrive_directory} (ADC)") + return + + # Fallback to pydrive2 interactive flow + drive = _authenticate_drive() + folder_id = _get_or_create_gdrive_folder(drive, config.gdrive_directory) + + existing = drive.ListFile({"q": f"title='{backup_file}' and '{folder_id}' in parents and trashed=false"}).GetList() + if existing: + if overwrite: + logger.info(f"Overwriting existing file: {backup_file}") + file_to_upload = existing[0] + else: + raise FileExistsError(f"File {backup_file} already exists in /{config.gdrive_directory}. 
Use -o to overwrite.") + else: + file_to_upload = drive.CreateFile({"title": backup_file, "parents": [{"id": folder_id}]}) + + file_to_upload.SetContentFile(data_path) + file_to_upload.Upload() + logger.info(f"Successfully uploaded {backup_file} to Google Drive /{config.gdrive_directory}") + + +def pull_from_drive(config: AppConfig, version: Optional[str], local_db_script: Optional[str], overwrite: bool = True) -> None: + data_path = config.local_db_path + if not data_path: + raise ValueError("local_db_path must be configured in order to restore the database file") + + drive = _authenticate_drive() + folder_id = _get_or_create_gdrive_folder(drive, config.gdrive_directory) + + if version: + backup_file = f"local-db-{version}.bin" + files = drive.ListFile({"q": f"title='{backup_file}' and '{folder_id}' in parents and trashed=false"}).GetList() + if not files: + raise FileNotFoundError(f"No backup found with version: {version}") + file_to_download = files[0] + else: + files = drive.ListFile({ + "q": f"'{folder_id}' in parents and trashed=false and title contains 'local-db-' and title contains '.bin'", + "orderBy": "modifiedDate desc", + "maxResults": 1, + }).GetList() + if not files: + raise FileNotFoundError(f"No backups found in /{config.gdrive_directory} folder") + file_to_download = files[0] + backup_file = file_to_download["title"] + + logger.info(f"Downloading {backup_file} from Google Drive") + # Prefer ADC if available + adc_service = _get_adc_drive_service() + tmp_download = f".download_{backup_file}" + if adc_service: + # find file id + q = f"name='{backup_file}' and '{folder_id}' in parents and trashed=false" + res = adc_service.files().list(q=q, fields="files(id,name)").execute() + files = res.get("files", []) + if not files: + raise FileNotFoundError(f"No backup found with name: {backup_file}") + file_id = files[0]["id"] + request = adc_service.files().get_media(fileId=file_id) + fh = open(tmp_download, "wb") + downloader = MediaIoBaseDownload(fh, request) + done = False + while not done: + status, done = downloader.next_chunk() + fh.close() + else: + # Download into a temporary location then move into place + file_to_download.GetContentFile(tmp_download) + + if os.path.exists(data_path) and not overwrite: + raise FileExistsError(f"Local DB file exists at {data_path}; use overwrite option to replace") + + # ensure parent dir exists + parent = os.path.dirname(data_path) + if parent: + os.makedirs(parent, exist_ok=True) + + shutil.copy2(tmp_download, data_path) + os.remove(tmp_download) + logger.info(f"Restored backup to {data_path}") + + # Optionally run helper restore script if provided + if local_db_script: + _run_local_db_command(local_db_script, ["restore", version or "latest"]) diff --git a/docs/DRIVE_INTERACTIVE.md b/docs/DRIVE_INTERACTIVE.md new file mode 100644 index 0000000..fed57dd --- /dev/null +++ b/docs/DRIVE_INTERACTIVE.md @@ -0,0 +1,63 @@ +# Interactive Google Drive (pydrive2) setup and usage + +This document explains how to run the interactive OAuth flow (pydrive2 LocalWebserverAuth) so you can upload and download your local DB binary to *your* Google Drive using the project's CLI. + +Summary +- The CLI uses pydrive2 for the interactive browser OAuth flow. When you run `cli.py db push` without ADC, pydrive2 will open a browser so you can sign in with your Google account and grant Drive permissions. The files uploaded will belong to the account you authorize. 
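+
+For reference, the Drive-sync commands read two keys from `config.yaml` (see `commands/config.py`); the values below are illustrative:
+
+```yaml
+# Path to the emulator data file that `db push` uploads and `db pull` restores.
+local_db_path: .local/datastore/local-db.bin
+# Google Drive folder name where backups are stored (defaults to "datastore").
+gdrive_directory: datastore
+```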
+
+Steps
+
+1) Create an OAuth 2.0 client ID in Google Cloud Console (recommended)
+
+   - Open https://console.cloud.google.com/apis/credentials
+   - Click "Create credentials" → "OAuth client ID".
+   - Application type: "Desktop app" (or "Other") is fine for local use.
+   - Name it (e.g. `local-storage-utils CLI`) and create it.
+   - Download the JSON file and save it as `client_secrets.json` in the project root (the same directory where `cli.py` lives).
+
+2) Install and activate the Python environment
+
+   ```bash
+   python -m venv .venv  # if you don't already have the project venv
+   source .venv/bin/activate
+   pip install -r requirements.txt
+   ```
+
+3) Run the interactive push
+
+   - Dry-run first to confirm the path and filename (no changes to Drive):
+
+   ```bash
+   .venv/bin/python cli.py db push --dry-run
+   ```
+
+   - When ready to upload, run:
+
+   ```bash
+   .venv/bin/python cli.py db push
+   ```
+
+   - pydrive2 will open a browser window so you can authorize. If the browser doesn't open automatically, the command prints a URL you can paste into any browser. Complete the consent flow and return to the terminal.
+
+4) Verify the upload
+
+   - List backups in the configured Drive folder:
+
+   ```bash
+   .venv/bin/python cli.py db list
+   ```
+
+   - Alternatively, visit https://drive.google.com and inspect the folder named in your `config.yaml` (`gdrive_directory`, default `datastore`).
+
+Notes and troubleshooting
+
+- pydrive2 saves the OAuth token locally after you complete the flow (when configured to persist credentials), so subsequent runs should not require re-authenticating unless the token expires or is revoked. Watch the CLI output to see which file was created.
+- If you get errors about a missing `client_secrets.json`, make sure the file you downloaded is named exactly `client_secrets.json` and sits in the working directory from which you run the CLI.
+- If you prefer not to create a client ID, you can still use ADC (gcloud) or a service account; see the other options in the project README.
+- If you want the tool to request the narrower `drive.file` permission instead of full Drive access, change the requested scope in `commands/drive_sync.py`; you must then re-run the interactive flow (or adjust ADC) so the new scope takes effect.
+
+Security
+
+- Keep `client_secrets.json` and any token files out of source control (add them to `.gitignore`).
+- For automated usage (CI), prefer a service account key and the `GOOGLE_APPLICATION_CREDENTIALS` environment variable instead of interactive OAuth.
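+
+As a sketch of the CI path (file path is illustrative): when Application Default Credentials point at a service-account key that carries the Drive scope, `db push` should use the ADC client and skip the interactive flow entirely:
+
+```bash
+# Hypothetical CI step: let ADC pick up a service-account key, then push.
+export GOOGLE_APPLICATION_CREDENTIALS=/secrets/drive-backup-sa.json
+python cli.py db push --overwrite
+```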
diff --git a/pyproject.toml b/pyproject.toml index 38aba55..c44d6cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "google-cloud-datastore>=2.19.0", "PyYAML>=6.0.1", "typer>=0.12.3", + "pydrive2>=1.20.0", ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index e8f4d09..5af2ff0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ google-cloud-datastore>=2.19.0 PyYAML>=6.0.1 -typer>=0.12.3 \ No newline at end of file +typer>=0.12.3 +pydrive2>=1.20.0 \ No newline at end of file diff --git a/tests/test_drive_sync_unit.py b/tests/test_drive_sync_unit.py new file mode 100644 index 0000000..05b5fc1 --- /dev/null +++ b/tests/test_drive_sync_unit.py @@ -0,0 +1,306 @@ +import os +import tempfile +from unittest.mock import Mock, patch +import pytest + +from commands.config import AppConfig +from commands.drive_sync import ( + _get_local_db_path, + _get_or_create_gdrive_folder, + _run_local_db_command, + push_to_drive, + pull_from_drive, + list_backups, +) + + +def test_get_local_db_path_from_override(): + config = AppConfig(local_db_path="/config/path/local-db") + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = tmp.name + try: + result = _get_local_db_path(config, tmp_path) + assert result == tmp_path + finally: + os.unlink(tmp_path) + + +def test_get_local_db_path_from_config(): + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = tmp.name + try: + config = AppConfig(local_db_path=tmp_path) + result = _get_local_db_path(config, None) + assert result == tmp_path + finally: + os.unlink(tmp_path) + + +def test_get_local_db_path_missing_raises(): + config = AppConfig() + with pytest.raises(ValueError, match="local-db path must be provided"): + _get_local_db_path(config, None) + + +def test_get_local_db_path_not_found_raises(): + config = AppConfig(local_db_path="/nonexistent/local-db") + with pytest.raises(FileNotFoundError, match="local-db binary not found"): + _get_local_db_path(config, None) + + +def test_get_or_create_datastore_folder_existing(): + mock_drive = Mock() + mock_file_list = Mock() + mock_file_list.GetList.return_value = [{"id": "folder-123", "title": "datastore"}] + mock_drive.ListFile.return_value = mock_file_list + + folder_id = _get_or_create_gdrive_folder(mock_drive, "datastore") + assert folder_id == "folder-123" + mock_drive.CreateFile.assert_not_called() + + +def test_get_or_create_datastore_folder_creates_new(): + mock_drive = Mock() + mock_file_list = Mock() + mock_file_list.GetList.return_value = [] + mock_drive.ListFile.return_value = mock_file_list + + mock_folder = Mock() + mock_folder.__getitem__ = Mock(return_value="new-folder-123") + mock_drive.CreateFile.return_value = mock_folder + + folder_id = _get_or_create_gdrive_folder(mock_drive, "datastore") + assert folder_id == "new-folder-123" + mock_drive.CreateFile.assert_called_once() + mock_folder.Upload.assert_called_once() + + +def test_run_local_db_command_success(): + with tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".sh") as tmp: + tmp.write("#!/bin/bash\necho 'Success'\n") + tmp_path = tmp.name + os.chmod(tmp_path, 0o755) + try: + out = _run_local_db_command(tmp_path, ["arg1"]) + assert "Success" in out + finally: + os.unlink(tmp_path) + + +def test_run_local_db_command_failure(): + with tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".sh") as tmp: + tmp.write("#!/bin/bash\nexit 1\n") + tmp_path = tmp.name + os.chmod(tmp_path, 0o755) + try: + with 
pytest.raises(RuntimeError, match="Command failed"): + _run_local_db_command(tmp_path, ["arg1"]) + finally: + os.unlink(tmp_path) + + +@patch("commands.drive_sync._authenticate_drive") +@patch("commands.drive_sync.os.path.exists") +def test_push_to_drive_with_version(mock_exists, mock_auth): + mock_exists.return_value = True + mock_drive = Mock() + mock_auth.return_value = mock_drive + + mock_folder_list = Mock() + mock_folder_list.GetList.return_value = [{"id": "folder-123"}] + + mock_file_list = Mock() + mock_file_list.GetList.return_value = [] + + mock_drive.ListFile.side_effect = [mock_folder_list, mock_file_list] + + mock_file = Mock() + mock_file.__getitem__ = Mock(return_value="folder-123") + mock_drive.CreateFile.return_value = mock_file + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = tmp.name + try: + config = AppConfig(local_db_path=tmp_path) + push_to_drive(config, "2024-01-01", False, None) + + # Should upload the data file located at local_db_path + mock_file.SetContentFile.assert_called_once_with(tmp_path) + mock_file.Upload.assert_called_once() + finally: + os.unlink(tmp_path) + + +@patch("commands.drive_sync._authenticate_drive") +@patch("commands.drive_sync.os.path.exists") +@patch("commands.drive_sync.datetime") +def test_push_to_drive_without_version(mock_datetime, mock_exists, mock_auth): + mock_now = Mock() + mock_now.strftime.return_value = "2024-12-25" + mock_datetime.now.return_value = mock_now + mock_exists.return_value = True + + mock_drive = Mock() + mock_auth.return_value = mock_drive + + mock_folder_list = Mock() + mock_folder_list.GetList.return_value = [{"id": "folder-123"}] + + mock_file_list = Mock() + mock_file_list.GetList.return_value = [] + + mock_drive.ListFile.side_effect = [mock_folder_list, mock_file_list] + + mock_file = Mock() + mock_file.__getitem__ = Mock(return_value="folder-123") + mock_drive.CreateFile.return_value = mock_file + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = tmp.name + try: + config = AppConfig(local_db_path=tmp_path) + push_to_drive(config, None, False, None) + + mock_file.SetContentFile.assert_called_once_with(tmp_path) + finally: + os.unlink(tmp_path) + + +@patch("commands.drive_sync._authenticate_drive") +@patch("commands.drive_sync.os.path.exists") +def test_push_to_drive_overwrite(mock_exists, mock_auth): + mock_exists.return_value = True + mock_drive = Mock() + mock_auth.return_value = mock_drive + + existing_file = Mock() + existing_file.__getitem__ = Mock(return_value="existing-file-123") + mock_file_list = Mock() + mock_file_list.GetList.return_value = [existing_file] + + mock_folder_list = Mock() + mock_folder_list.GetList.return_value = [{"id": "folder-123"}] + + mock_drive.ListFile.side_effect = [mock_folder_list, mock_file_list] + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = tmp.name + try: + config = AppConfig(local_db_path=tmp_path) + push_to_drive(config, "2024-01-01", True, None) + + existing_file.SetContentFile.assert_called_once_with(tmp_path) + existing_file.Upload.assert_called_once() + finally: + os.unlink(tmp_path) + + +@patch("commands.drive_sync._authenticate_drive") +@patch("commands.drive_sync.os.path.exists") +def test_push_to_drive_no_overwrite_raises(mock_exists, mock_auth): + mock_exists.return_value = True + mock_drive = Mock() + mock_auth.return_value = mock_drive + + existing_file = Mock() + mock_file_list = Mock() + mock_file_list.GetList.return_value = [existing_file] + + mock_folder_list = Mock() + 
mock_folder_list.GetList.return_value = [{"id": "folder-123"}] + + mock_drive.ListFile.side_effect = [mock_folder_list, mock_file_list] + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = tmp.name + try: + config = AppConfig(local_db_path=tmp_path) + with pytest.raises(FileExistsError, match="already exists"): + push_to_drive(config, "2024-01-01", False, None) + finally: + os.unlink(tmp_path) + + +@patch("commands.drive_sync._authenticate_drive") +@patch("commands.drive_sync.shutil.copy2") +@patch("commands.drive_sync.os.remove") +def test_pull_from_drive_with_version(mock_remove, mock_copy, mock_auth): + mock_drive = Mock() + mock_auth.return_value = mock_drive + + mock_file = Mock() + mock_file.__getitem__ = Mock(return_value="local-db-2024-01-01.bin") + mock_file_list = Mock() + mock_file_list.GetList.return_value = [mock_file] + + mock_folder_list = Mock() + mock_folder_list.GetList.return_value = [{"id": "folder-123"}] + + mock_drive.ListFile.side_effect = [mock_folder_list, mock_file_list] + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = tmp.name + try: + config = AppConfig(local_db_path=tmp_path) + pull_from_drive(config, "2024-01-01", None) + + # The drive file should be asked to download into a temp location + mock_file.GetContentFile.assert_called_once() + # The downloaded temp should be copied into local_db_path + mock_copy.assert_called_once() + mock_remove.assert_called_once() + finally: + os.unlink(tmp_path) + + +@patch("commands.drive_sync._authenticate_drive") +@patch("commands.drive_sync.shutil.copy2") +@patch("commands.drive_sync.os.remove") +def test_pull_from_drive_without_version(mock_remove, mock_copy, mock_auth): + mock_drive = Mock() + mock_auth.return_value = mock_drive + + mock_file = Mock() + mock_file.__getitem__ = Mock(return_value="local-db-latest.bin") + mock_file_list = Mock() + mock_file_list.GetList.return_value = [mock_file] + + mock_folder_list = Mock() + mock_folder_list.GetList.return_value = [{"id": "folder-123"}] + + mock_drive.ListFile.side_effect = [mock_folder_list, mock_file_list] + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = tmp.name + try: + config = AppConfig(local_db_path=tmp_path) + pull_from_drive(config, None, None) + + mock_file.GetContentFile.assert_called_once() + mock_copy.assert_called_once() + mock_remove.assert_called_once() + finally: + os.unlink(tmp_path) + + +@patch("commands.drive_sync._authenticate_drive") +def test_pull_from_drive_version_not_found(mock_auth): + mock_drive = Mock() + mock_auth.return_value = mock_drive + + mock_file_list = Mock() + mock_file_list.GetList.return_value = [] + + mock_folder_list = Mock() + mock_folder_list.GetList.return_value = [{"id": "folder-123"}] + + mock_drive.ListFile.side_effect = [mock_folder_list, mock_file_list] + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + tmp_path = tmp.name + try: + config = AppConfig(local_db_path=tmp_path) + with pytest.raises(FileNotFoundError, match="No backup found"): + pull_from_drive(config, "nonexistent", None) + finally: + os.unlink(tmp_path)
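+
+
+# Note: the tests in this module mock pydrive2's drive handle and the relevant
+# filesystem calls, so they run offline; assuming pytest is installed in the
+# project venv, a typical invocation is:
+#   pytest tests/test_drive_sync_unit.py -v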