Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ build-backend = "uv_build"
[project]
name = "contraqctor"
description = "A library for managing data contracts and quality control in behavioral datasets."
authors = [
{ name = "Bruno Cruz", email = "[email protected]" },
]
authors = [{ name = "Bruno Cruz", email = "[email protected]" }]
requires-python = ">=3.11"
license = "MIT"

Expand Down Expand Up @@ -42,13 +40,7 @@ Changelog = "https://github.com/AllenNeuralDynamics/contraqctor/releases"

[dependency-groups]

dev = [
'codespell',
'pytest',
'pytest-cov',
'ruff',
'interrogate'
]
dev = ['codespell', 'pytest', 'pytest-cov', 'ruff', 'interrogate']

docs = [
'mkdocs',
Expand Down Expand Up @@ -81,10 +73,19 @@ testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
env = ["MPLBACKEND=Agg"]

[tool.interrogate]
ignore-init-method = true
ignore-magic = true
ignore_module = true
fail-under = 100
exclude = ["__init__.py", "tests", "docs", "build", "setup.py", "examples", "site"]
exclude = [
"__init__.py",
"tests",
"docs",
"build",
"setup.py",
"examples",
"site",
]
50 changes: 47 additions & 3 deletions src/contraqctor/_typing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Any, Generic, Protocol, TypeAlias, TypeVar, Union, cast, final
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeAlias, TypeVar, Union, cast, final

if TYPE_CHECKING:
from contraqctor.contract.base import DataStream
else:
DataStream = Any # type: ignore

# Type variables
TData = TypeVar("TData", bound=Union[Any, "_UnsetData"])
"""TypeVar: Type variable bound to Union[Any, "_UnsetData"] for data types."""
TData = TypeVar("TData", bound=Union[Any, "_UnsetData", "ErrorOnLoad"])
"""TypeVar: Type variable bound to Union[Any, "_UnsetData", "ErrorOnLoad"] for data types."""

TReaderParams = TypeVar("TReaderParams", contravariant=True)
"""TypeVar: Contravariant type variable for reader parameters."""
Expand Down Expand Up @@ -157,3 +162,42 @@ def is_unset(obj: Any) -> bool:
True if the object is an unset sentinel value, False otherwise.
"""
return (obj is UnsetReader) or (obj is UnsetParams) or (obj is UnsetData)


@final
class ErrorOnLoad:
"""A class representing data that failed to load due to an error.

Attributes:
datastream: The data stream that failed to load.
error: The exception that occurred during data loading.

This class is used to encapsulate information about data loading failures,
allowing for graceful handling of errors in data processing workflows.
"""

def __init__(self, data_stream: "DataStream", exception: Exception | None = None):
self._data_stream = data_stream
self._exception = exception

@property
def data_stream(self) -> "DataStream":
"""The data stream that failed to load."""
return self._data_stream

@property
def exception(self) -> Exception | None:
"""The exception that occurred during data loading, if any."""
return self._exception

def __repr__(self):
return f"<ErrorData stream={self.data_stream} error={self.exception}>"

def raise_from_error(self):
"""Raises the stored error if it exists.

Raises:
The stored exception if it is not None.
"""
if self.exception is not None:
raise self.exception
82 changes: 67 additions & 15 deletions src/contraqctor/contract/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
import abc
import dataclasses
import os
from typing import Any, ClassVar, Dict, Generator, Generic, List, Optional, Protocol, Self, TypeVar, runtime_checkable
from typing import (
Any,
ClassVar,
Dict,
Generator,
Generic,
List,
Optional,
Protocol,
Self,
TypeVar,
cast,
runtime_checkable,
)

from semver import Version
from typing_extensions import override
Expand Down Expand Up @@ -201,7 +214,16 @@ def has_data(self) -> bool:
Returns:
bool: True if data has been loaded, False otherwise.
"""
return not _typing.is_unset(self._data)
return not (_typing.is_unset(self._data) or self.has_error)

@property
def has_error(self) -> bool:
"""Check if the data stream encountered an error during loading.

Returns:
bool: True if an error occurred, False otherwise.
"""
return isinstance(self._data, _typing.ErrorOnLoad)

@property
def data(self) -> _typing.TData:
Expand All @@ -213,9 +235,22 @@ def data(self) -> _typing.TData:
Raises:
ValueError: If data has not been loaded yet.
"""
if self.has_error:
cast(_typing.ErrorOnLoad, self._data).raise_from_error()
if not self.has_data:
raise ValueError("Data has not been loaded yet.")
return self._data
return cast(_typing.TData, self._data)

def clear(self) -> Self:
"""Clear the loaded data from the data stream.

Resets the data to an unset state, allowing for reloading.

Returns:
Self: The data stream instance for method chaining.
"""
self._data = _typing.UnsetData
return self

def load(self) -> Self:
"""Load data into the data stream.
Expand All @@ -239,7 +274,10 @@ def load(self) -> Self:
print(f"Loaded {len(df)} rows")
```
"""
self._data = self.read()
try:
self._data = self.read()
except Exception as e: # pylint: disable=broad-except
self._data = _typing.ErrorOnLoad(self, exception=e)
return self

def __str__(self):
Expand All @@ -266,9 +304,27 @@ def __iter__(self) -> Generator["DataStream", None, None]:
Yields:
DataStream: Child data streams (none for base DataStream).
"""
yield
return
yield # This line is unreachable but needed for the generator type

def collect_errors(self) -> List[_typing.ErrorOnLoad]:
"""Collect all errors from this stream and its children.

def load_all(self, strict: bool = False) -> list[tuple["DataStream", Exception], None, None]:
Performs a depth-first traversal to gather all ErrorOnLoad instances.

Returns:
List[ErrorOnLoad]: List of all errors raised on load encountered in the hierarchy.
"""
errors = []
if self.has_error:
errors.append(cast(_typing.ErrorOnLoad, self._data))
for stream in self:
if stream is None:
continue
errors.extend(stream.collect_errors())
return errors

def load_all(self, strict: bool = False) -> Self:
"""Recursively load this data stream and all child streams.

Performs depth-first traversal to load all streams in the hierarchy.
Expand All @@ -293,17 +349,13 @@ def load_all(self, strict: bool = False) -> list[tuple["DataStream", Exception],
```
"""
self.load()
exceptions = []
for stream in self:
if stream is None:
continue
try:
exceptions += stream.load_all(strict=strict)
except Exception as e:
if strict:
raise e
exceptions.append((stream, e))
return exceptions
stream.load_all(strict=strict)
if stream.has_error and strict:
cast(_typing.ErrorOnLoad, stream.data).raise_from_error()
return self


TDataStream = TypeVar("TDataStream", bound=DataStream[Any, Any])
Expand Down Expand Up @@ -411,7 +463,7 @@ def at(self) -> _At[TDataStream]:
return self._at

@override
def load(self):
def load(self) -> Self:
"""Load data for this collection.

Overrides the base method to add validation that loaded data is a list of DataStreams.
Expand Down
11 changes: 5 additions & 6 deletions src/contraqctor/qc/contract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as t

from .._typing import ErrorOnLoad
from ..contract.base import DataStream
from .base import Suite

Expand Down Expand Up @@ -36,9 +37,7 @@ class ContractTestSuite(Suite):
```
"""

def __init__(
self, loading_errors: list[tuple[DataStream, Exception]], exclude: t.Optional[list[DataStream]] = None
):
def __init__(self, loading_errors: list[ErrorOnLoad], exclude: t.Optional[list[DataStream]] = None):
"""Initialize the contract test suite.

Args:
Expand All @@ -51,9 +50,9 @@ def __init__(

def test_has_errors_on_load(self):
"""Check if any non-excluded data streams had loading errors."""
errors = [(ds, err) for ds, err in self.loading_errors if ds not in self.exclude]
errors = [err for err in self.loading_errors if err.data_stream not in self.exclude]
if errors:
str_errors = "\n".join([f"{ds.resolved_name}" for ds, _ in errors])
str_errors = "\n".join([f"{err.data_stream.resolved_name}" for err in errors])
return self.fail_test(
None,
f"The following DataStreams raised errors on load: \n {str_errors}",
Expand All @@ -64,7 +63,7 @@ def test_has_errors_on_load(self):

def test_has_excluded_as_warnings(self):
"""Check if any excluded data streams had loading errors and report as warnings."""
warnings = [(ds, err) for ds, err in self.loading_errors if ds in self.exclude]
warnings = [err for err in self.loading_errors if err.data_stream in self.exclude]
if warnings:
return self.warn_test(
None,
Expand Down
Empty file removed tests/__init__.py
Empty file.
2 changes: 2 additions & 0 deletions tests/test_contract/conftest.py → tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import tempfile
import typing as t
from pathlib import Path
Expand All @@ -7,6 +8,7 @@
import pytest
from pydantic import BaseModel

os.environ["MPLBACKEND"] = "Agg"
from contraqctor.contract.base import (
DataStream,
FilePathBaseParam,
Expand Down
29 changes: 22 additions & 7 deletions tests/test_contract/test_core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import pytest
from conftest import SimpleDataStream, SimpleParams

from contraqctor import _typing
from contraqctor.contract.base import DataStreamCollection

from .conftest import SimpleDataStream, SimpleParams


class TestDataStream:
"""Tests for the DataStream class."""
Expand Down Expand Up @@ -76,6 +75,19 @@ def test_invalid_name(self, text_file):
name="test::invalid", description="Test stream", reader_params=SimpleParams(path=text_file)
)

def test_clear_data(self, text_file):
"""Test clearing loaded data."""
stream = SimpleDataStream(name="test", reader_params=SimpleParams(path=text_file))

stream.load()
assert stream.has_data

stream.clear()
assert not stream.has_data

with pytest.raises(ValueError):
_ = stream.data # Accessing data after clearing should raise ValueError


class TestDataStreamCollection:
"""Tests for the DataStreamCollection anonymous class."""
Expand Down Expand Up @@ -225,7 +237,7 @@ def test_load_all_success(self, text_file):
collection = DataStreamCollection(name="collection", data_streams=[stream1, stream2])

result = collection.load_all()
assert result == [] # No exceptions
assert result.collect_errors() == []
assert stream1.has_data
assert stream2.has_data

Expand All @@ -239,14 +251,17 @@ def test_load_all_with_exception(self, text_file, temp_dir):
collection = DataStreamCollection(name="collection", data_streams=[stream1, stream2])

result = collection.load_all()

assert len(result) == 1
assert result[0][0] == stream2
assert isinstance(result[0][1], FileNotFoundError)
errors = result.collect_errors()
assert len(errors) == 1
assert errors[0].data_stream == stream2
assert isinstance(errors[0].exception, FileNotFoundError)

assert stream1.has_data
assert not stream2.has_data

with pytest.raises(FileNotFoundError):
raise errors[0].exception

def test_load_all_strict(self, text_file, temp_dir):
"""Test load_all with strict=True."""
stream1 = SimpleDataStream(name="stream1", reader_params=SimpleParams(path=text_file))
Expand Down
3 changes: 1 addition & 2 deletions tests/test_contract/test_json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

import pandas as pd
from conftest import MockModel

from contraqctor.contract.json import (
Json,
Expand All @@ -12,8 +13,6 @@
PydanticModelParams,
)

from .conftest import MockModel


class TestJson:
"""Tests for the Json class."""
Expand Down
Loading