diff --git a/pyproject.toml b/pyproject.toml
index ec4769d..a3b06b5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,9 +5,7 @@ build-backend = "uv_build"
 
 [project]
 name = "contraqctor"
 description = "A library for managing data contracts and quality control in behavioral datasets."
-authors = [
-    { name = "Bruno Cruz", email = "bruno.cruz@alleninstitute.org" },
-]
+authors = [{ name = "Bruno Cruz", email = "bruno.cruz@alleninstitute.org" }]
 requires-python = ">=3.11"
 license = "MIT"
@@ -42,13 +40,7 @@ Changelog = "https://github.com/AllenNeuralDynamics/contraqctor/releases"
 
 [dependency-groups]
-dev = [
-    'codespell',
-    'pytest',
-    'pytest-cov',
-    'ruff',
-    'interrogate'
-]
+dev = ['codespell', 'pytest', 'pytest-cov', 'ruff', 'interrogate']
 
 docs = [
     'mkdocs',
@@ -81,10 +73,19 @@ testpaths = ["tests"]
 python_files = ["test_*.py"]
 python_classes = ["Test*"]
 python_functions = ["test_*"]
+env = ["MPLBACKEND=Agg"]
 
 [tool.interrogate]
 ignore-init-method = true
 ignore-magic = true
 ignore_module = true
 fail-under = 100
-exclude = ["__init__.py", "tests", "docs", "build", "setup.py", "examples", "site"]
+exclude = [
+    "__init__.py",
+    "tests",
+    "docs",
+    "build",
+    "setup.py",
+    "examples",
+    "site",
+]
diff --git a/src/contraqctor/_typing.py b/src/contraqctor/_typing.py
index 00b98f4..4de0fdc 100644
--- a/src/contraqctor/_typing.py
+++ b/src/contraqctor/_typing.py
@@ -1,8 +1,13 @@
-from typing import Any, Generic, Protocol, TypeAlias, TypeVar, Union, cast, final
+from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeAlias, TypeVar, Union, cast, final
+
+if TYPE_CHECKING:
+    from contraqctor.contract.base import DataStream
+else:
+    DataStream = Any  # type: ignore
 
 # Type variables
-TData = TypeVar("TData", bound=Union[Any, "_UnsetData"])
-"""TypeVar: Type variable bound to Union[Any, "_UnsetData"] for data types."""
+TData = TypeVar("TData", bound=Union[Any, "_UnsetData", "ErrorOnLoad"])
+"""TypeVar: Type variable bound to Union[Any, "_UnsetData", "ErrorOnLoad"] for data types."""
 
 TReaderParams = TypeVar("TReaderParams", contravariant=True)
 """TypeVar: Contravariant type variable for reader parameters."""
@@ -157,3 +162,42 @@ def is_unset(obj: Any) -> bool:
         True if the object is an unset sentinel value, False otherwise.
     """
     return (obj is UnsetReader) or (obj is UnsetParams) or (obj is UnsetData)
+
+
+@final
+class ErrorOnLoad:
+    """A class representing data that failed to load due to an error.
+
+    Attributes:
+        data_stream: The data stream that failed to load.
+        exception: The exception that occurred during data loading.
+
+    This class is used to encapsulate information about data loading failures,
+    allowing for graceful handling of errors in data processing workflows.
+    """
+
+    def __init__(self, data_stream: "DataStream", exception: Exception | None = None):
+        self._data_stream = data_stream
+        self._exception = exception
+
+    @property
+    def data_stream(self) -> "DataStream":
+        """The data stream that failed to load."""
+        return self._data_stream
+
+    @property
+    def exception(self) -> Exception | None:
+        """The exception that occurred during data loading, if any."""
+        return self._exception
+
+    def __repr__(self):
+        return f"<ErrorOnLoad(data_stream={self._data_stream}, exception={self._exception!r})>"
+
+    def raise_from_error(self):
+        """Raises the stored error if it exists.
+
+        Raises:
+            The stored exception if it is not None.
+        """
+        if self.exception is not None:
+            raise self.exception
diff --git a/src/contraqctor/contract/base.py b/src/contraqctor/contract/base.py
index db40f22..8020961 100644
--- a/src/contraqctor/contract/base.py
+++ b/src/contraqctor/contract/base.py
@@ -1,7 +1,20 @@
 import abc
 import dataclasses
 import os
-from typing import Any, ClassVar, Dict, Generator, Generic, List, Optional, Protocol, Self, TypeVar, runtime_checkable
+from typing import (
+    Any,
+    ClassVar,
+    Dict,
+    Generator,
+    Generic,
+    List,
+    Optional,
+    Protocol,
+    Self,
+    TypeVar,
+    cast,
+    runtime_checkable,
+)
 
 from semver import Version
 from typing_extensions import override
@@ -201,7 +214,16 @@ def has_data(self) -> bool:
         Returns:
             bool: True if data has been loaded, False otherwise.
         """
-        return not _typing.is_unset(self._data)
+        return not (_typing.is_unset(self._data) or self.has_error)
+
+    @property
+    def has_error(self) -> bool:
+        """Check if the data stream encountered an error during loading.
+
+        Returns:
+            bool: True if an error occurred, False otherwise.
+        """
+        return isinstance(self._data, _typing.ErrorOnLoad)
 
     @property
     def data(self) -> _typing.TData:
@@ -213,9 +235,22 @@ def data(self) -> _typing.TData:
         Raises:
             ValueError: If data has not been loaded yet.
         """
+        if self.has_error:
+            cast(_typing.ErrorOnLoad, self._data).raise_from_error()
         if not self.has_data:
             raise ValueError("Data has not been loaded yet.")
-        return self._data
+        return cast(_typing.TData, self._data)
+
+    def clear(self) -> Self:
+        """Clear the loaded data from the data stream.
+
+        Resets the data to an unset state, allowing for reloading.
+
+        Returns:
+            Self: The data stream instance for method chaining.
+        """
+        self._data = _typing.UnsetData
+        return self
 
     def load(self) -> Self:
         """Load data into the data stream.
@@ -239,7 +274,10 @@ def load(self) -> Self:
             print(f"Loaded {len(df)} rows")
             ```
         """
-        self._data = self.read()
+        try:
+            self._data = self.read()
+        except Exception as e:  # pylint: disable=broad-except
+            self._data = _typing.ErrorOnLoad(self, exception=e)
         return self
 
     def __str__(self):
@@ -266,9 +304,27 @@ def __iter__(self) -> Generator["DataStream", None, None]:
         Yields:
             DataStream: Child data streams (none for base DataStream).
         """
-        yield
+        return
+        yield  # This line is unreachable but needed for the generator type
+
+    def collect_errors(self) -> List[_typing.ErrorOnLoad]:
+        """Collect all errors from this stream and its children.
 
-    def load_all(self, strict: bool = False) -> list[tuple["DataStream", Exception], None, None]:
+        Performs a depth-first traversal to gather all ErrorOnLoad instances.
+
+        Returns:
+            List[ErrorOnLoad]: List of all errors raised on load encountered in the hierarchy.
+        """
+        errors = []
+        if self.has_error:
+            errors.append(cast(_typing.ErrorOnLoad, self._data))
+        for stream in self:
+            if stream is None:
+                continue
+            errors.extend(stream.collect_errors())
+        return errors
+
+    def load_all(self, strict: bool = False) -> Self:
         """Recursively load this data stream and all child streams.
 
         Performs depth-first traversal to load all streams in the hierarchy.
@@ -293,17 +349,13 @@ def load_all(self, strict: bool = False) -> list[tuple["DataStream", Exception],
             ```
         """
         self.load()
-        exceptions = []
         for stream in self:
             if stream is None:
                 continue
-            try:
-                exceptions += stream.load_all(strict=strict)
-            except Exception as e:
-                if strict:
-                    raise e
-                exceptions.append((stream, e))
-        return exceptions
+            stream.load_all(strict=strict)
+            if stream.has_error and strict:
+                cast(_typing.ErrorOnLoad, stream.data).raise_from_error()
+        return self
 
 
 TDataStream = TypeVar("TDataStream", bound=DataStream[Any, Any])
@@ -411,7 +463,7 @@ def at(self) -> _At[TDataStream]:
         return self._at
 
     @override
-    def load(self):
+    def load(self) -> Self:
         """Load data for this collection.
 
         Overrides the base method to add validation that loaded data is a list of DataStreams.
diff --git a/src/contraqctor/qc/contract.py b/src/contraqctor/qc/contract.py
index 1db5057..d7e9143 100644
--- a/src/contraqctor/qc/contract.py
+++ b/src/contraqctor/qc/contract.py
@@ -1,5 +1,6 @@
 import typing as t
 
+from .._typing import ErrorOnLoad
 from ..contract.base import DataStream
 from .base import Suite
 
@@ -36,9 +37,7 @@ class ContractTestSuite(Suite):
         ```
     """
 
-    def __init__(
-        self, loading_errors: list[tuple[DataStream, Exception]], exclude: t.Optional[list[DataStream]] = None
-    ):
+    def __init__(self, loading_errors: list[ErrorOnLoad], exclude: t.Optional[list[DataStream]] = None):
         """Initialize the contract test suite.
 
         Args:
@@ -51,9 +50,9 @@ def __init__(
 
     def test_has_errors_on_load(self):
         """Check if any non-excluded data streams had loading errors."""
-        errors = [(ds, err) for ds, err in self.loading_errors if ds not in self.exclude]
+        errors = [err for err in self.loading_errors if err.data_stream not in self.exclude]
         if errors:
-            str_errors = "\n".join([f"{ds.resolved_name}" for ds, _ in errors])
+            str_errors = "\n".join([f"{err.data_stream.resolved_name}" for err in errors])
             return self.fail_test(
                 None,
                 f"The following DataStreams raised errors on load: \n {str_errors}",
@@ -64,7 +63,7 @@ def test_has_errors_on_load(self):
 
     def test_has_excluded_as_warnings(self):
         """Check if any excluded data streams had loading errors and report as warnings."""
-        warnings = [(ds, err) for ds, err in self.loading_errors if ds in self.exclude]
+        warnings = [err for err in self.loading_errors if err.data_stream in self.exclude]
         if warnings:
             return self.warn_test(
                 None,
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/test_contract/conftest.py b/tests/conftest.py
similarity index 98%
rename from tests/test_contract/conftest.py
rename to tests/conftest.py
index 7952de2..96d64f8 100644
--- a/tests/test_contract/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,5 @@
 import json
+import os
 import tempfile
 import typing as t
 from pathlib import Path
@@ -7,6 +8,7 @@
 import pytest
 from pydantic import BaseModel
 
+os.environ["MPLBACKEND"] = "Agg"
 from contraqctor.contract.base import (
     DataStream,
     FilePathBaseParam,
diff --git a/tests/test_contract/test_core.py b/tests/test_contract/test_core.py
index 4a93730..a43f9c4 100644
--- a/tests/test_contract/test_core.py
+++ b/tests/test_contract/test_core.py
@@ -1,10 +1,9 @@
 import pytest
+from conftest import SimpleDataStream, SimpleParams
 
 from contraqctor import _typing
 from contraqctor.contract.base import DataStreamCollection
 
-from .conftest import SimpleDataStream, SimpleParams
-
 
 class TestDataStream:
     """Tests for the DataStream class."""
@@ -76,6 +75,19 @@ def test_invalid_name(self, text_file):
                 name="test::invalid", description="Test stream", reader_params=SimpleParams(path=text_file)
             )
 
+    def test_clear_data(self, text_file):
+        """Test clearing loaded data."""
+        stream = SimpleDataStream(name="test", reader_params=SimpleParams(path=text_file))
+
+        stream.load()
+        assert stream.has_data
+
+        stream.clear()
+        assert not stream.has_data
+
+        with pytest.raises(ValueError):
+            _ = stream.data  # Accessing data after clearing should raise ValueError
+
 
 class TestDataStreamCollection:
     """Tests for the DataStreamCollection anonymous class."""
@@ -225,7 +237,7 @@ def test_load_all_success(self, text_file):
         collection = DataStreamCollection(name="collection", data_streams=[stream1, stream2])
 
         result = collection.load_all()
-        assert result == []  # No exceptions
+        assert result.collect_errors() == []
         assert stream1.has_data
         assert stream2.has_data
 
@@ -239,14 +251,17 @@ def test_load_all_with_exception(self, text_file, temp_dir):
         collection = DataStreamCollection(name="collection", data_streams=[stream1, stream2])
 
         result = collection.load_all()
-
-        assert len(result) == 1
-        assert result[0][0] == stream2
-        assert isinstance(result[0][1], FileNotFoundError)
+        errors = result.collect_errors()
+        assert len(errors) == 1
+        assert errors[0].data_stream == stream2
+        assert isinstance(errors[0].exception, FileNotFoundError)
 
         assert stream1.has_data
         assert not stream2.has_data
 
+        with pytest.raises(FileNotFoundError):
+            raise errors[0].exception
+
     def test_load_all_strict(self, text_file, temp_dir):
         """Test load_all with strict=True."""
         stream1 = SimpleDataStream(name="stream1", reader_params=SimpleParams(path=text_file))
diff --git a/tests/test_contract/test_json.py b/tests/test_contract/test_json.py
index 0cf23f4..83c8446 100644
--- a/tests/test_contract/test_json.py
+++ b/tests/test_contract/test_json.py
@@ -1,6 +1,7 @@
 import json
 
 import pandas as pd
+from conftest import MockModel
 
 from contraqctor.contract.json import (
     Json,
@@ -12,8 +13,6 @@
     PydanticModelParams,
 )
 
-from .conftest import MockModel
-
 
 class TestJson:
     """Tests for the Json class."""
diff --git a/tests/test_qc/test_contract.py b/tests/test_qc/test_qc_contract.py
similarity index 70%
rename from tests/test_qc/test_contract.py
rename to tests/test_qc/test_qc_contract.py
index 8ec8339..1a7bd80 100644
--- a/tests/test_qc/test_contract.py
+++ b/tests/test_qc/test_qc_contract.py
@@ -1,38 +1,37 @@
 import pytest
+from conftest import SimpleDataStream, SimpleParams
 
-from contraqctor.contract.base import DataStream
+from contraqctor._typing import ErrorOnLoad
+from contraqctor.contract.base import DataStreamCollection
 from contraqctor.qc.base import Status
 from contraqctor.qc.contract import ContractTestSuite
 
 
-class MockDataStream(DataStream):
-    """Mock DataStream class for testing."""
+def raise_value_error(*args, **kwargs):
+    raise ValueError("Simulated load error")
 
-    def __init__(self, name="test"):
-        super().__init__(name=name)
-        self._resolved_name = name
 
-    @property
-    def resolved_name(self):
-        return self._resolved_name
+def raise_io_error(*args, **kwargs):
+    raise IOError("Simulated load error")
 
 
 @pytest.fixture
-def loading_errors():
-    ds1 = MockDataStream(name="stream1")
-    ds2 = MockDataStream(name="stream2")
-    ds3 = MockDataStream(name="stream3")
+def loading_errors(text_file) -> list[ErrorOnLoad]:
+    stream1 = SimpleDataStream(name="stream1", reader_params=SimpleParams(path=text_file))
+    stream1._reader = raise_value_error
 
-    err1 = ValueError("Error loading stream1")
-    err2 = FileNotFoundError("File not found for stream2")
-    err3 = RuntimeError("Error in stream3")
+    stream2 = SimpleDataStream(name="stream2", reader_params=SimpleParams(path=text_file))
+    stream2._reader = raise_io_error
 
-    return [(ds1, err1), (ds2, err2), (ds3, err3)]
+    collection = DataStreamCollection(name="collection", data_streams=[stream1, stream2])
+
+    collection.load_all()
+    return collection.collect_errors()
 
 
 @pytest.fixture
 def excluded_streams(loading_errors):
-    return [loading_errors[0][0]]
+    return [loading_errors[0].data_stream]
 
 
 class TestContractTestSuite:
@@ -54,7 +53,9 @@ def test_has_errors_on_load_with_errors(self, loading_errors):
         result = suite.test_has_errors_on_load()
 
         assert result.status == Status.FAILED
+        assert result.message is not None
         assert "raised errors on load" in result.message
+        assert result.context is not None
         assert "errors" in result.context
         assert len(result.context["errors"]) == len(loading_errors)
 
@@ -64,6 +65,7 @@ def test_has_errors_on_load_no_errors(self):
         result = suite.test_has_errors_on_load()
 
         assert result.status == Status.PASSED
+        assert result.message is not None
         assert "All DataStreams loaded successfully" in result.message
 
     def test_has_errors_on_load_with_excludes(self, loading_errors, excluded_streams):
@@ -72,11 +74,12 @@ def test_has_errors_on_load_with_excludes(self, loading_errors, excluded_streams
         result = suite.test_has_errors_on_load()
 
         assert result.status == Status.FAILED
+        assert result.context is not None
         assert "errors" in result.context
         assert len(result.context["errors"]) == len(loading_errors) - len(excluded_streams)
         excluded_names = [ds.resolved_name for ds in excluded_streams]
-        for ds, _ in result.context["errors"]:
-            assert ds.resolved_name not in excluded_names
+        for err in result.context["errors"]:
+            assert err.data_stream.resolved_name not in excluded_names
 
     def test_has_excluded_as_warnings_with_excludes(self, loading_errors, excluded_streams):
         """Test test_has_excluded_as_warnings method with excluded streams."""
@@ -84,10 +87,11 @@ def test_has_excluded_as_warnings_with_excludes(self, loading_errors, excluded_s
         result = suite.test_has_excluded_as_warnings()
 
         assert result.status == Status.WARNING
+        assert result.context is not None
         assert "warnings" in result.context
         assert len(result.context["warnings"]) == len(excluded_streams)
-        for ds, _ in result.context["warnings"]:
-            assert ds in excluded_streams
+        for err in result.context["warnings"]:
+            assert err.data_stream in excluded_streams
 
     def test_has_excluded_as_warnings_no_excludes(self, loading_errors):
         """Test test_has_excluded_as_warnings method with no excluded streams."""
@@ -95,6 +99,7 @@ def test_has_excluded_as_warnings_no_excludes(self, loading_errors):
         result = suite.test_has_excluded_as_warnings()
 
         assert result.status == Status.PASSED
+        assert result.message is not None
         assert "No excluded DataStreams raised errors" in result.message
 
     def test_has_excluded_as_warnings_empty_errors(self, excluded_streams):
@@ -103,4 +108,5 @@ def test_has_excluded_as_warnings_empty_errors(self, excluded_streams):
         result = suite.test_has_excluded_as_warnings()
 
         assert result.status == Status.PASSED
+        assert result.message is not None
        assert "No excluded DataStreams raised errors" in result.message