diff --git a/oc4ids_datastore_pipeline/database.py b/oc4ids_datastore_pipeline/database.py
index f160c1b..f699b56 100644
--- a/oc4ids_datastore_pipeline/database.py
+++ b/oc4ids_datastore_pipeline/database.py
@@ -8,6 +8,8 @@
     Engine,
     String,
     create_engine,
+    delete,
+    select,
 )
 from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
 
@@ -45,3 +47,14 @@ def save_dataset(dataset: Dataset) -> None:
     with Session(get_engine()) as session:
         session.merge(dataset)
         session.commit()
+
+
+def delete_dataset(dataset_id: str) -> None:
+    with Session(get_engine()) as session:
+        session.execute(delete(Dataset).where(Dataset.dataset_id == dataset_id))
+        session.commit()
+
+
+def get_dataset_ids() -> list[str]:
+    with Session(get_engine()) as session:
+        return list(session.scalars(select(Dataset.dataset_id)))
diff --git a/oc4ids_datastore_pipeline/pipeline.py b/oc4ids_datastore_pipeline/pipeline.py
index cca5fbd..aabe607 100644
--- a/oc4ids_datastore_pipeline/pipeline.py
+++ b/oc4ids_datastore_pipeline/pipeline.py
@@ -9,7 +9,12 @@
 import requests
 from libcoveoc4ids.api import oc4ids_json_output
-from oc4ids_datastore_pipeline.database import Dataset, save_dataset
+from oc4ids_datastore_pipeline.database import (
+    Dataset,
+    delete_dataset,
+    get_dataset_ids,
+    save_dataset,
+)
 from oc4ids_datastore_pipeline.registry import (
     fetch_registered_datasets,
     get_license_name_from_url,
 )
@@ -122,11 +127,20 @@ def process_dataset(dataset_name: str, dataset_url: str) -> None:
         logger.warning(f"Failed to process dataset {dataset_name} with error {e}")
 
 
-def process_datasets() -> None:
+def process_deleted_datasets(registered_datasets: dict[str, str]) -> None:
+    stored_datasets = get_dataset_ids()
+    deleted_datasets = set(stored_datasets) - registered_datasets.keys()
+    for dataset_id in deleted_datasets:
+        logger.info(f"Dataset {dataset_id} is no longer in the registry, deleting")
+        delete_dataset(dataset_id)
+
+
+def process_registry() -> None:
     registered_datasets = fetch_registered_datasets()
+    process_deleted_datasets(registered_datasets)
     for name, url in registered_datasets.items():
         process_dataset(name, url)
 
 
 def run() -> None:
-    process_datasets()
+    process_registry()
diff --git a/oc4ids_datastore_pipeline/registry.py b/oc4ids_datastore_pipeline/registry.py
index 72fa14d..d921e39 100644
--- a/oc4ids_datastore_pipeline/registry.py
+++ b/oc4ids_datastore_pipeline/registry.py
@@ -22,9 +22,13 @@ def fetch_registered_datasets() -> dict[str, str]:
         }
         registered_datasets_count = len(registered_datasets)
         logger.info(f"Fetched URLs for {registered_datasets_count} datasets")
-        return registered_datasets
     except Exception as e:
         raise Exception("Failed to fetch datasets list from registry", e)
+    if registered_datasets_count < 1:
+        raise Exception(
+            "Zero datasets returned from registry, likely an upstream error"
+        )
+    return registered_datasets
 
 
 def fetch_license_mappings() -> dict[str, str]:
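A note on the removal logic added to pipeline.py above: process_deleted_datasets() takes the set difference between the dataset ids already stored and the keys of the registry response, so anything stored but no longer registered is deleted. A minimal sketch of that set difference, using hypothetical ids rather than real registry data:

    # Hypothetical values; in the real pipeline these come from
    # get_dataset_ids() and fetch_registered_datasets().
    stored_datasets = ["old_dataset", "test_dataset"]
    registered_datasets = {"test_dataset": "https://test_dataset.json"}

    # Wrapping the list in set() makes the difference explicit (and keeps type
    # checkers happy): ids present in the datastore but absent from the registry.
    deleted_datasets = set(stored_datasets) - registered_datasets.keys()
    assert deleted_datasets == {"old_dataset"}
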
diff --git a/tests/test_database.py b/tests/test_database.py
new file mode 100644
index 0000000..e8f2926
--- /dev/null
+++ b/tests/test_database.py
@@ -0,0 +1,54 @@
+import datetime
+from typing import Any, Generator
+
+import pytest
+from pytest_mock import MockerFixture
+from sqlalchemy import create_engine
+
+from oc4ids_datastore_pipeline.database import (
+    Base,
+    Dataset,
+    delete_dataset,
+    get_dataset_ids,
+    save_dataset,
+)
+
+
+@pytest.fixture(autouse=True)
+def before_and_after_each(mocker: MockerFixture) -> Generator[Any, Any, Any]:
+    engine = create_engine("sqlite:///:memory:")
+    patch_get_engine = mocker.patch("oc4ids_datastore_pipeline.database.get_engine")
+    patch_get_engine.return_value = engine
+    Base.metadata.create_all(engine)
+    yield
+    engine.dispose()
+
+
+def test_save_dataset() -> None:
+    dataset = Dataset(
+        dataset_id="test_dataset",
+        source_url="https://test_dataset.json",
+        publisher_name="test_publisher",
+        json_url="data/test_dataset.json",
+        updated_at=datetime.datetime.now(datetime.UTC),
+    )
+    save_dataset(dataset)
+
+    assert get_dataset_ids() == ["test_dataset"]
+
+
+def test_delete_dataset() -> None:
+    dataset = Dataset(
+        dataset_id="test_dataset",
+        source_url="https://test_dataset.json",
+        publisher_name="test_publisher",
+        json_url="data/test_dataset.json",
+        updated_at=datetime.datetime.now(datetime.UTC),
+    )
+    save_dataset(dataset)
+
+    assert get_dataset_ids() == ["test_dataset"]
+
+    delete_dataset("test_dataset")
+
+    assert get_dataset_ids() == []
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index f914fe6..4a1718d 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -8,6 +8,7 @@
 from oc4ids_datastore_pipeline.pipeline import (
     download_json,
     process_dataset,
+    process_deleted_datasets,
     transform_to_csv_and_xlsx,
     validate_json,
     write_json_to_file,
@@ -102,6 +103,21 @@ def test_transform_to_csv_and_xlsx_catches_exception(mocker: MockerFixture) -> N
     assert xlsx_path is None
 
 
+def test_process_deleted_datasets(mocker: MockerFixture) -> None:
+    patch_get_dataset_ids = mocker.patch(
+        "oc4ids_datastore_pipeline.pipeline.get_dataset_ids"
+    )
+    patch_get_dataset_ids.return_value = ["old_dataset", "test_dataset"]
+    patch_delete_dataset = mocker.patch(
+        "oc4ids_datastore_pipeline.pipeline.delete_dataset"
+    )
+
+    registered_datasets = {"test_dataset": "https://test_dataset.json"}
+    process_deleted_datasets(registered_datasets)
+
+    patch_delete_dataset.assert_called_once_with("old_dataset")
+
+
 def test_process_dataset_catches_exception(mocker: MockerFixture) -> None:
     patch_download_json = mocker.patch(
         "oc4ids_datastore_pipeline.pipeline.download_json"
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 6398524..77f9eb2 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -38,6 +38,20 @@ def test_fetch_registered_datasets_raises_failure_exception(
     assert "Mocked exception" in str(exc_info.value)
 
 
+def test_fetch_registered_datasets_raises_exception_when_no_datasets(
+    mocker: MockerFixture,
+) -> None:
+    mock_response = MagicMock()
+    mock_response.json.return_value = {"records": {}}
+    patch_get = mocker.patch("oc4ids_datastore_pipeline.registry.requests.get")
+    patch_get.return_value = mock_response
+
+    with pytest.raises(Exception) as exc_info:
+        fetch_registered_datasets()
+
+    assert "Zero datasets returned from registry" in str(exc_info.value)
+
+
 def test_fetch_license_mappings(mocker: MockerFixture) -> None:
     mock_response = MagicMock()
     mock_response.json.return_value = {
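
The new database helpers and their tests lean on SQLAlchemy's 2.0-style select() and delete() constructs plus Session.merge(). A self-contained sketch of that pattern against an in-memory SQLite engine; the single-column Dataset model here is a simplification for illustration, not the real table definition:

    from sqlalchemy import String, create_engine, delete, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Dataset(Base):
        __tablename__ = "dataset"
        # One column is enough to illustrate; the real model has several more.
        dataset_id: Mapped[str] = mapped_column(String, primary_key=True)


    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # merge() inserts or updates by primary key, as save_dataset() does.
        session.merge(Dataset(dataset_id="test_dataset"))
        session.commit()

        # Selecting a single column yields scalar rows, hence session.scalars().
        assert list(session.scalars(select(Dataset.dataset_id))) == ["test_dataset"]

        # Bulk delete by primary key, mirroring delete_dataset().
        session.execute(delete(Dataset).where(Dataset.dataset_id == "test_dataset"))
        session.commit()
        assert list(session.scalars(select(Dataset.dataset_id))) == []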