
Commit e0857b4

feat: add suite_result and tagging
1 parent 509dcc3

54 files changed (+1380, -239 lines)


.github/workflows/main.yaml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ jobs:
           uv run python scripts/sanity_checks.py
       - name: Run tests
         run: |
-          uv run pytest tests/ --cov=dataframe_expectations
+          uv run pytest tests/ -n auto --tb=line --cov=dataframe_expectations

   lint:
     runs-on: ubuntu-latest

dataframe_expectations/__init__.py

Lines changed: 19 additions & 1 deletion
@@ -9,4 +9,22 @@
 # Catch all exceptions to handle various edge cases in different environments
 __version__ = "0.0.0.dev0"

-__all__ = []
+from dataframe_expectations.core.suite_result import (
+    ExpectationResult,
+    SuiteExecutionResult,
+    serialize_violations,
+)
+from dataframe_expectations.suite import (
+    DataFrameExpectationsSuite,
+    DataFrameExpectationsSuiteRunner,
+    DataFrameExpectationsSuiteFailure,
+)
+
+__all__ = [
+    "ExpectationResult",
+    "SuiteExecutionResult",
+    "serialize_violations",
+    "DataFrameExpectationsSuite",
+    "DataFrameExpectationsSuiteRunner",
+    "DataFrameExpectationsSuiteFailure",
+]
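With these re-exports in place, the result types and suite classes become importable from the package root. A minimal sketch using only the names confirmed by this diff (the suite-building and running API lives in dataframe_expectations.suite, which is not shown in this commit):

from dataframe_expectations import (
    DataFrameExpectationsSuite,
    DataFrameExpectationsSuiteFailure,
    SuiteExecutionResult,
    serialize_violations,
)

# DataFrameExpectationsSuiteFailure presumably signals a failed run; its raise
# site is in dataframe_expectations.suite, which this diff does not include.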
Lines changed: 13 additions & 1 deletion
@@ -1,3 +1,15 @@
 """Core base classes and interfaces for DataFrame expectations."""

-__all__ = []
+from dataframe_expectations.core.suite_result import (
+    ExpectationResult,
+    ExpectationStatus,
+    SuiteExecutionResult,
+    serialize_violations,
+)
+
+__all__ = [
+    "ExpectationResult",
+    "ExpectationStatus",
+    "SuiteExecutionResult",
+    "serialize_violations",
+]
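ExpectationStatus, defined in the new suite_result module further down in this commit, mixes in str. A quick sketch of what that buys consumers of the result objects:

from dataframe_expectations.core import ExpectationStatus

# str-mixin enum members compare equal to their plain string values, which keeps
# serialized results and status filtering code simple.
assert ExpectationStatus.PASSED == "passed"
assert ExpectationStatus.SKIPPED.value == "skipped"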

dataframe_expectations/core/aggregation_expectation.py

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Union
+from typing import List, Optional, Union

 from dataframe_expectations.core.types import DataFrameLike, DataFrameType
 from dataframe_expectations.core.expectation import DataFrameExpectation
@@ -20,6 +20,7 @@ def __init__(
         expectation_name: str,
         column_names: List[str],
         description: str,
+        tags: Optional[List[str]] = None,
     ):
         """
         Template for implementing DataFrame aggregation expectations, where data is first aggregated
@@ -28,7 +29,10 @@ def __init__(
         :param expectation_name: The name of the expectation. This will be used during logging.
         :param column_names: The list of column names to aggregate on.
         :param description: A description of the expectation used in logging.
+        :param tags: Optional tags as list of strings in "key:value" format.
+            Example: ["priority:high", "env:test"]
         """
+        super().__init__(tags=tags)
         self.expectation_name = expectation_name
         self.column_names = column_names
         self.description = description
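The new tags argument simply threads through to the base class. A hedged sketch of a concrete subclass passing tags along (the subclass name, the "DataFrameAggregationExpectation" base-class name, and the omitted abstract pandas/PySpark hooks are assumptions for illustration, not part of this commit):

from typing import List, Optional

# Hypothetical subclass; abstract validation hooks omitted for brevity.
class ExpectationGroupRowCount(DataFrameAggregationExpectation):
    def __init__(self, column_names: List[str], max_rows: int,
                 tags: Optional[List[str]] = None):
        super().__init__(
            expectation_name="ExpectationGroupRowCount",
            column_names=column_names,
            description=f"each group has at most {max_rows} rows",
            tags=tags,  # e.g. ["priority:high", "env:test"]
        )
        self.max_rows = max_rows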

dataframe_expectations/core/column_expectation.py

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Callable
+from typing import Callable, List, Optional

 from dataframe_expectations.core.types import DataFrameLike, DataFrameType
 from dataframe_expectations.core.expectation import DataFrameExpectation
@@ -23,6 +23,7 @@ def __init__(
         fn_violations_pyspark: Callable,
         description: str,
         error_message: str,
+        tags: Optional[List[str]] = None,
     ):
         """
         Template for implementing DataFrame column expectations, where a column value is tested against a
@@ -34,7 +35,10 @@ def __init__(
         :param fn_violations_pyspark: Function to find violations in a PySpark DataFrame.
         :param description: A description of the expectation used in logging.
         :param error_message: The error message to return if the expectation fails.
+        :param tags: Optional tags as list of strings in "key:value" format.
+            Example: ["priority:high", "env:test"]
         """
+        super().__init__(tags=tags)
         self.column_name = column_name
         self.expectation_name = expectation_name
         self.fn_violations_pandas = fn_violations_pandas
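The column template gets the same treatment. A sketch of constructing one with tags, using keyword arguments only; the "DataFrameColumnExpectation" class name, its direct instantiability, and the exact contract of the violation callables are assumptions — only the parameter names appear in this hunk:

# Hypothetical construction for illustration.
not_null = DataFrameColumnExpectation(
    column_name="user_id",
    expectation_name="ExpectUserIdNotNull",
    fn_violations_pandas=lambda df: df[df["user_id"].isna()],
    fn_violations_pyspark=lambda df: df.filter(df["user_id"].isNull()),
    description="user_id must not be null",
    error_message="Found null values in user_id",
    tags=["priority:high", "env:test"],
)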

dataframe_expectations/core/expectation.py

Lines changed: 28 additions & 17 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import cast
+from typing import List, Optional, cast

 from pandas import DataFrame as PandasDataFrame
 from pyspark.sql import DataFrame as PySparkDataFrame
@@ -12,6 +12,7 @@
     PySparkConnectDataFrame = None  # type: ignore[misc,assignment]

 from dataframe_expectations.core.types import DataFrameLike, DataFrameType
+from dataframe_expectations.core.tagging import TagSet
 from dataframe_expectations.result_message import (
     DataFrameExpectationResultMessage,
 )
@@ -22,6 +23,14 @@ class DataFrameExpectation(ABC):
     Base class for DataFrame expectations.
     """

+    def __init__(self, tags: Optional[List[str]] = None):
+        """
+        Initialize the base expectation with optional tags.
+        :param tags: Optional tags as list of strings in "key:value" format.
+            Example: ["priority:high", "env:test"]
+        """
+        self.tags = TagSet(tags)
+
     def get_expectation_name(self) -> str:
         """
         Returns the class name as the expectation name.
@@ -48,29 +57,31 @@ def infer_data_frame_type(cls, data_frame: DataFrameLike) -> DataFrameType:
         """
         Infer the DataFrame type based on the provided DataFrame.
         """
-        if isinstance(data_frame, PandasDataFrame):
-            return DataFrameType.PANDAS
-        elif isinstance(data_frame, PySparkDataFrame):
-            return DataFrameType.PYSPARK
-        elif PySparkConnectDataFrame is not None and isinstance(
-            data_frame, PySparkConnectDataFrame
-        ):
-            return DataFrameType.PYSPARK
-        else:
-            raise ValueError(f"Unsupported DataFrame type: {type(data_frame)}")
+        match data_frame:
+            case PandasDataFrame():
+                return DataFrameType.PANDAS
+            case PySparkDataFrame():
+                return DataFrameType.PYSPARK
+            case _ if PySparkConnectDataFrame is not None and isinstance(
+                data_frame, PySparkConnectDataFrame
+            ):
+                return DataFrameType.PYSPARK
+            case _:
+                raise ValueError(f"Unsupported DataFrame type: {type(data_frame)}")

     def validate(self, data_frame: DataFrameLike, **kwargs):
         """
         Validate the DataFrame against the expectation.
         """
         data_frame_type = self.infer_data_frame_type(data_frame)

-        if data_frame_type == DataFrameType.PANDAS:
-            return self.validate_pandas(data_frame=data_frame, **kwargs)
-        elif data_frame_type == DataFrameType.PYSPARK:
-            return self.validate_pyspark(data_frame=data_frame, **kwargs)
-        else:
-            raise ValueError(f"Unsupported DataFrame type: {data_frame_type}")
+        match data_frame_type:
+            case DataFrameType.PANDAS:
+                return self.validate_pandas(data_frame=data_frame, **kwargs)
+            case DataFrameType.PYSPARK:
+                return self.validate_pyspark(data_frame=data_frame, **kwargs)
+            case _:
+                raise ValueError(f"Unsupported DataFrame type: {data_frame_type}")

     @abstractmethod
     def validate_pandas(
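The refactor swaps isinstance chains for structural pattern matching, which requires Python 3.10+. A standalone sketch of the same idiom, separate from the library code: `case ClassName():` is a class pattern (effectively an isinstance check), and a bare `case _` with an `if` guard mirrors the optional Spark Connect check above.

import pandas as pd

def kind_of(obj) -> str:
    # Class patterns match by isinstance; the guard can hold arbitrary conditions.
    match obj:
        case pd.DataFrame():
            return "pandas"
        case list() | tuple():
            return "sequence"
        case _:
            raise ValueError(f"Unsupported type: {type(obj)}")

print(kind_of(pd.DataFrame({"a": [1]})))  # pandas
print(kind_of([1, 2, 3]))                 # sequence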
Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
+"""Suite execution result models for capturing validation outcomes."""
+
+from datetime import datetime
+from typing import Any, Dict, List, Literal, Optional
+
+from pydantic import BaseModel, Field, computed_field
+
+from dataframe_expectations.core.types import DataFrameType, DataFrameLike
+from dataframe_expectations.core.tagging import TagSet
+
+
+from enum import Enum
+
+
+class ExpectationStatus(str, Enum):
+    PASSED = "passed"
+    FAILED = "failed"
+    SKIPPED = "skipped"
+
+
+class ExpectationResult(BaseModel):
+    """
+    Representation of a single expectation result within a suite execution.
+    Captures the outcome (passed, failed, skipped) using status.
+    Does not store raw dataframes, only serialized violation samples.
+    """
+
+    expectation_name: str = Field(..., description="Name of the expectation class")
+    description: str = Field(..., description="Human-readable description of the expectation")
+    status: ExpectationStatus = Field(..., description="Outcome status: passed, failed, or skipped")
+    tags: Optional[TagSet] = Field(
+        default=None, description="User-defined tags for this specific expectation"
+    )
+    error_message: Optional[str] = Field(
+        default=None, description="Error message if expectation failed"
+    )
+    violation_count: Optional[int] = Field(
+        default=None, description="Total count of violations (if applicable)"
+    )
+    violation_sample: Optional[List[Dict[str, Any]]] = Field(
+        default=None,
+        description="Sample of violations as list of dicts (limited by violation_sample_limit)",
+    )
+
+    model_config = {"frozen": True, "arbitrary_types_allowed": True}  # Make immutable, allow TagSet
+
+
+class SuiteExecutionResult(BaseModel):
+    """Result of a complete suite execution.
+    Captures all metadata about the suite run including timing, dataframe info,
+    and individual expectation results. Does not store raw dataframes.
+    """
+
+    suite_name: Optional[str] = Field(default=None, description="Optional name for the suite")
+    context: Dict[str, Any] = Field(
+        default_factory=dict, description="Additional runtime metadata (e.g., job_id, environment)"
+    )
+    applied_filters: TagSet = Field(
+        default_factory=TagSet, description="Tag filters that were applied to select expectations"
+    )
+    tag_match_mode: Optional[Literal["any", "all"]] = Field(
+        default=None, description="How tags were matched: 'any' (OR) or 'all' (AND)"
+    )
+    results: List[ExpectationResult] = Field(
+        ..., description="Results for each expectation in execution order (including skipped)"
+    )
+    start_time: datetime = Field(..., description="Suite execution start timestamp")
+    end_time: datetime = Field(..., description="Suite execution end timestamp")
+    dataframe_type: DataFrameType = Field(..., description="Type of dataframe validated")
+    dataframe_row_count: int = Field(..., description="Number of rows in validated dataframe")
+    dataframe_was_cached: bool = Field(
+        default=False, description="Whether PySpark dataframe was cached during execution"
+    )
+
+    model_config = {"frozen": True, "arbitrary_types_allowed": True}  # Make immutable, allow TagSet
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def total_duration_seconds(self) -> float:
+        """Total execution time in seconds."""
+        return (self.end_time - self.start_time).total_seconds()
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def total_expectations(self) -> int:
+        """Total number of expectations in the suite (including skipped)."""
+        return len(self.results)
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def total_passed(self) -> int:
+        """Number of expectations that passed."""
+        return sum(1 for r in self.results if r.status == ExpectationStatus.PASSED)
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def total_failed(self) -> int:
+        """Number of expectations that failed."""
+        return sum(1 for r in self.results if r.status == ExpectationStatus.FAILED)
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def total_skipped(self) -> int:
+        """Number of expectations that were skipped due to tag filtering."""
+        return sum(1 for r in self.results if r.status == ExpectationStatus.SKIPPED)
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def pass_rate(self) -> float:
+        """Percentage of expectations that passed (0.0 to 1.0)."""
+        executed = self.total_passed + self.total_failed
+        if executed == 0:
+            return 1.0
+        return self.total_passed / executed
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def success(self) -> bool:
+        """Whether all executed expectations passed (ignores skipped)."""
+        return self.total_failed == 0
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def passed_expectations(self) -> List[ExpectationResult]:
+        """List of expectations that passed."""
+        return [r for r in self.results if r.status == ExpectationStatus.PASSED]
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def failed_expectations(self) -> List[ExpectationResult]:
+        """List of expectations that failed."""
+        return [r for r in self.results if r.status == ExpectationStatus.FAILED]
+
+    @computed_field  # type: ignore[misc]
+    @property
+    def skipped_expectations(self) -> List[ExpectationResult]:
+        """List of expectations that were skipped due to tag filtering."""
+        return [r for r in self.results if r.status == ExpectationStatus.SKIPPED]
+
+
+def serialize_violations(
+    violations_df: Optional[DataFrameLike],
+    df_type: DataFrameType,
+    limit: int = 5,
+) -> tuple[Optional[int], Optional[List[Dict[str, Any]]]]:
+    """Serialize violation dataframe to count and sample for storage.
+
+    Converts dataframes to JSON-serializable format without storing raw dataframes.
+
+    :param violations_df: DataFrame containing violations (pandas or PySpark).
+    :param df_type: Type of the violations dataframe.
+    :param limit: Maximum number of violation rows to include in sample.
+    :return: Tuple of (total_count, sample_as_list_of_dicts).
+    """
+    if violations_df is None:
+        return None, None
+
+    count: Optional[int] = None
+    sample: Optional[list[dict[str, Any]]] = None
+
+    try:
+        if df_type == DataFrameType.PANDAS:
+            pandas_df = violations_df  # type: ignore[assignment]
+            count = len(pandas_df)  # type: ignore[arg-type]
+            sample = pandas_df.head(limit).to_dict("records")  # type: ignore[assignment,union-attr]
+        elif df_type == DataFrameType.PYSPARK:
+            pyspark_df = violations_df  # type: ignore[assignment]
+            count = pyspark_df.count()  # type: ignore[assignment]
+            sample = pyspark_df.limit(limit).toPandas().to_dict("records")  # type: ignore[assignment,operator]
+
+        return count, sample
+    except Exception:
+        # If serialization fails, return None to avoid breaking the suite
+        return None, None
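A short usage sketch of the new module, assuming only the import paths and field names added in this commit; the expectation name and the numbers are made up for illustration:

from datetime import datetime, timedelta

import pandas as pd

from dataframe_expectations import ExpectationResult, SuiteExecutionResult, serialize_violations
from dataframe_expectations.core import ExpectationStatus
from dataframe_expectations.core.types import DataFrameType

# Serialize a violations frame into a count plus a bounded, JSON-friendly sample.
violations = pd.DataFrame({"user_id": [None, None, None]})
count, sample = serialize_violations(violations, DataFrameType.PANDAS, limit=2)  # 3, two dicts

result = SuiteExecutionResult(
    suite_name="orders_quality",
    results=[
        ExpectationResult(
            expectation_name="ExpectUserIdNotNull",
            description="user_id must not be null",
            status=ExpectationStatus.FAILED,
            error_message="Found null values in user_id",
            violation_count=count,
            violation_sample=sample,
        ),
    ],
    start_time=datetime.now() - timedelta(seconds=3),
    end_time=datetime.now(),
    dataframe_type=DataFrameType.PANDAS,
    dataframe_row_count=1000,
)

# Computed fields derive the summary numbers from the stored results.
print(result.total_failed, result.pass_rate, result.success)  # 1 0.0 False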
