Skip to content

Commit a19126a

Browse files
committed
chore: use enums for tag match mode instead of str
1 parent e702d53 commit a19126a

File tree

7 files changed

+97
-74
lines changed

7 files changed

+97
-74
lines changed

dataframe_expectations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
SuiteExecutionResult,
1515
serialize_violations,
1616
)
17+
from dataframe_expectations.core.types import TagMatchMode
1718
from dataframe_expectations.suite import (
1819
DataFrameExpectationsSuite,
1920
DataFrameExpectationsSuiteRunner,
@@ -27,4 +28,5 @@
2728
"DataFrameExpectationsSuite",
2829
"DataFrameExpectationsSuiteRunner",
2930
"DataFrameExpectationsSuiteFailure",
31+
"TagMatchMode",
3032
]

dataframe_expectations/core/suite_result.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Suite execution result models for capturing validation outcomes."""
22

33
from datetime import datetime
4-
from typing import Any, Dict, List, Literal, Optional
4+
from typing import Any, Dict, List, Optional
55

66
from pydantic import BaseModel, Field, computed_field
77

8-
from dataframe_expectations.core.types import DataFrameType, DataFrameLike
8+
from dataframe_expectations.core.types import DataFrameType, DataFrameLike, TagMatchMode
99
from dataframe_expectations.core.tagging import TagSet
1010
import logging
1111

@@ -44,7 +44,7 @@ class ExpectationResult(BaseModel):
4444
description="Sample of violations as list of dicts (limited by violation_sample_limit)",
4545
)
4646

47-
model_config = {"frozen": True, "arbitrary_types_allowed": True} # Make immutable, allow TagSet
47+
model_config = {"frozen": True} # Make immutable
4848

4949

5050
class SuiteExecutionResult(BaseModel):
@@ -60,8 +60,9 @@ class SuiteExecutionResult(BaseModel):
6060
applied_filters: TagSet = Field(
6161
default_factory=TagSet, description="Tag filters that were applied to select expectations"
6262
)
63-
tag_match_mode: Optional[Literal["any", "all"]] = Field(
64-
default=None, description="How tags were matched: 'any' (OR) or 'all' (AND)"
63+
tag_match_mode: Optional[TagMatchMode] = Field(
64+
default=None,
65+
description="How tags were matched: TagMatchMode.ANY (OR) or TagMatchMode.ALL (AND)",
6566
)
6667
results: List[ExpectationResult] = Field(
6768
..., description="Results for each expectation in execution order (including skipped)"
@@ -74,7 +75,7 @@ class SuiteExecutionResult(BaseModel):
7475
default=False, description="Whether PySpark dataframe was cached during execution"
7576
)
7677

77-
model_config = {"frozen": True, "arbitrary_types_allowed": True} # Make immutable, allow TagSet
78+
model_config = {"frozen": True} # Make immutable
7879

7980
@computed_field # type: ignore[misc]
8081
@property

dataframe_expectations/core/tagging.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88

99
from typing import Dict, List, Optional, Set
1010

11+
from pydantic import BaseModel, ConfigDict
1112

12-
class TagSet:
13+
14+
class TagSet(BaseModel):
1315
"""
1416
Collection of tags organized by key, supporting multiple values per key.
1517
@@ -19,7 +21,11 @@ class TagSet:
1921
Tags are specified as strings in "key:value" format.
2022
"""
2123

22-
def __init__(self, tags: Optional[List[str]] = None):
24+
tags: Dict[str, Set[str]] = {}
25+
26+
model_config = ConfigDict(frozen=True) # Make immutable
27+
28+
def __init__(self, tags: Optional[List[str]] = None, **data):
2329
"""
2430
Initialize TagSet from a list of tag strings.
2531
@@ -30,17 +36,22 @@ def __init__(self, tags: Optional[List[str]] = None):
3036
>>> TagSet(["priority:high", "env:test"])
3137
>>> TagSet(["priority:high", "priority:medium"]) # Multiple values for same key
3238
"""
33-
self._tags: Dict[str, Set[str]] = {}
34-
35-
if tags:
39+
# Parse tags if provided as list
40+
if tags is not None:
41+
parsed_tags: Dict[str, Set[str]] = {}
3642
for tag_string in tags:
37-
self._add_tag_string(tag_string)
43+
TagSet._parse_and_add_tag(tag_string, parsed_tags)
44+
data["tags"] = parsed_tags
45+
46+
super().__init__(**data)
3847

39-
def _add_tag_string(self, tag_string: str) -> None:
48+
@staticmethod
49+
def _parse_and_add_tag(tag_string: str, tags_dict: Dict[str, Set[str]]) -> None:
4050
"""
41-
Parse and add a tag string in "key:value" format.
51+
Parse and add a tag string to the provided dictionary.
4252
4353
:param tag_string: Tag string to parse
54+
:param tags_dict: Dictionary to add parsed tag to
4455
:raises ValueError: If format is invalid
4556
"""
4657
tag_string = tag_string.strip()
@@ -59,9 +70,9 @@ def _add_tag_string(self, tag_string: str) -> None:
5970
if not key or not value:
6071
raise ValueError("Tag key and value must be non-empty strings")
6172

62-
if key not in self._tags:
63-
self._tags[key] = set()
64-
self._tags[key].add(value)
73+
if key not in tags_dict:
74+
tags_dict[key] = set()
75+
tags_dict[key].add(value)
6576

6677
def has_any_tag_from(self, other: TagSet) -> bool:
6778
"""
@@ -84,14 +95,14 @@ def has_any_tag_from(self, other: TagSet) -> bool:
8495
other = TagSet(["priority:medium", "env:test"])
8596
self.has_any_tag_from(other) -> True (env:test matches)
8697
"""
87-
if not other._tags:
98+
if not other.tags:
8899
return True # Empty filter matches everything
89100

90101
# OR logic: any key with overlapping values
91-
for key, required_values in other._tags.items():
92-
if key in self._tags:
102+
for key, required_values in other.tags.items():
103+
if key in self.tags:
93104
# Check if there's any overlap between required values and our values
94-
if required_values & self._tags[key]:
105+
if required_values & self.tags[key]:
95106
return True
96107

97108
return False
@@ -117,38 +128,38 @@ def has_all_tags_from(self, other: TagSet) -> bool:
117128
other = TagSet(["priority:high", "env:prod"])
118129
self.has_all_tags_from(other) -> False (env:prod doesn't match)
119130
"""
120-
if not other._tags:
131+
if not other.tags:
121132
return True # Empty filter matches everything
122133

123134
# AND logic: all keys must have ALL required values present
124-
for key, required_values in other._tags.items():
125-
if key not in self._tags:
135+
for key, required_values in other.tags.items():
136+
if key not in self.tags:
126137
return False
127138
# Check if ALL required values are present in our values
128-
if not required_values.issubset(self._tags[key]):
139+
if not required_values.issubset(self.tags[key]):
129140
return False
130141

131142
return True
132143

133144
def is_empty(self) -> bool:
134145
"""Check if TagSet has no tags."""
135-
return len(self._tags) == 0
146+
return len(self.tags) == 0
136147

137148
def __len__(self) -> int:
138149
"""Return total number of unique tags (key:value pairs)."""
139-
return sum(len(values) for values in self._tags.values())
150+
return sum(len(values) for values in self.tags.values())
140151

141152
def __bool__(self) -> bool:
142153
"""Return True if TagSet has any tags."""
143-
return bool(self._tags)
154+
return bool(self.tags)
144155

145156
def __str__(self) -> str:
146157
"""String representation showing all tags."""
147-
tags = []
148-
for key in sorted(self._tags.keys()):
149-
for value in sorted(self._tags[key]):
150-
tags.append(f"{key}:{value}")
151-
return f"TagSet({', '.join(tags)})" if tags else "TagSet(empty)"
158+
tag_list = []
159+
for key in sorted(self.tags.keys()):
160+
for value in sorted(self.tags[key]):
161+
tag_list.append(f"{key}:{value}")
162+
return f"TagSet({', '.join(tag_list)})" if tag_list else "TagSet(empty)"
152163

153164
def __repr__(self) -> str:
154165
return self.__str__()

dataframe_expectations/core/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@ class DataFrameType(str, Enum):
1818
PYSPARK = "pyspark"
1919

2020

21+
class TagMatchMode(str, Enum):
22+
"""Enum for tag matching modes."""
23+
24+
ANY = "any" # OR logic: expectation matches if it has ANY of the filter tags
25+
ALL = "all" # AND logic: expectation matches if it has ALL of the filter tags
26+
27+
2128
class ExpectationCategory(str, Enum):
2229
"""Categories for expectations."""
2330

dataframe_expectations/suite.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from functools import wraps
2-
from typing import Any, Callable, Dict, List, Literal, Optional, cast
2+
from typing import Any, Callable, Dict, List, Optional, cast
33

4-
from dataframe_expectations.core.types import DataFrameLike
4+
from dataframe_expectations.core.types import DataFrameLike, TagMatchMode
55
from dataframe_expectations.core.tagging import TagSet
66
from dataframe_expectations.registry import (
77
DataFrameExpectationRegistry,
@@ -62,35 +62,32 @@ class DataFrameExpectationsSuiteRunner:
6262
def _matches_tag_filter(
6363
expectation: Any,
6464
filter_tag_set: TagSet,
65-
tag_match_mode: Literal["any", "all"],
65+
tag_match_mode: TagMatchMode,
6666
) -> bool:
6767
"""
6868
Check if an expectation matches the tag filter criteria.
6969
7070
:param expectation: Expectation instance to check.
7171
:param filter_tag_set: Tag filter to match against.
72-
:param tag_match_mode: Match mode - "any" (OR) or "all" (AND).
72+
:param tag_match_mode: Match mode - TagMatchMode.ANY (OR) or TagMatchMode.ALL (AND).
7373
:return: True if expectation matches filter, False otherwise.
74-
:raises ValueError: If tag_match_mode is invalid.
7574
"""
7675
exp_tag_set = expectation.get_tags()
7776

7877
# Check if expectation matches filter
7978
match tag_match_mode:
80-
case "any":
79+
case TagMatchMode.ANY:
8180
return exp_tag_set.has_any_tag_from(filter_tag_set)
82-
case "all":
81+
case TagMatchMode.ALL:
8382
return exp_tag_set.has_all_tags_from(filter_tag_set)
84-
case _:
85-
raise ValueError(f"Invalid tag_match_mode: {tag_match_mode}")
8683

8784
def __init__(
8885
self,
8986
expectations: List[Any],
9087
suite_name: Optional[str] = None,
9188
violation_sample_limit: int = 5,
9289
tags: Optional[List[str]] = None,
93-
tag_match_mode: Optional[Literal["any", "all"]] = None,
90+
tag_match_mode: Optional[TagMatchMode] = None,
9491
):
9592
"""
9693
Initialize the runner with a list of expectations and metadata.
@@ -101,10 +98,10 @@ def __init__(
10198
:param tags: Optional tag filters as list of strings in "key:value" format.
10299
Example: ["priority:high", "priority:medium"]
103100
If None or empty, all expectations will run.
104-
:param tag_match_mode: How to match tags - "any" (OR logic) or "all" (AND logic).
101+
:param tag_match_mode: How to match tags - TagMatchMode.ANY (OR logic) or TagMatchMode.ALL (AND logic).
105102
Required if tags are provided, must be None if tags are not provided.
106-
- "any": Expectation matches if it has ANY of the filter tags
107-
- "all": Expectation matches if it has ALL of the filter tags
103+
- TagMatchMode.ANY: Expectation matches if it has ANY of the filter tags
104+
- TagMatchMode.ALL: Expectation matches if it has ALL of the filter tags
108105
:raises ValueError: If tag_match_mode is provided without tags, or if tags are provided without tag_match_mode,
109106
or if tag filters result in zero expectations to run.
110107
"""
@@ -122,15 +119,21 @@ def __init__(
122119

123120
if not self.__filter_tag_set.is_empty() and tag_match_mode is None:
124121
raise ValueError(
125-
"tag_match_mode must be specified ('any' or 'all') when tags are provided."
122+
"tag_match_mode must be specified (TagMatchMode.ANY or TagMatchMode.ALL) when tags are provided."
126123
)
127124

128125
self.__tag_match_mode = tag_match_mode
129126

130127
# Filter expectations based on tags and track skipped ones
131128
if not self.__filter_tag_set.is_empty():
132129
# At this point, validation ensures tag_match_mode is not None
133-
assert tag_match_mode is not None
130+
# This check is for type narrowing (mypy/pyright)
131+
if tag_match_mode is None:
132+
# This should never happen due to validation above, but satisfies type checker
133+
raise ValueError(
134+
"tag_match_mode must be specified (TagMatchMode.ANY or TagMatchMode.ALL) when tags are provided."
135+
)
136+
134137
filtered = []
135138
skipped = []
136139
for exp in self.__all_expectations:
@@ -201,7 +204,7 @@ def run(
201204
data_frame: DataFrameLike,
202205
raise_on_failure: bool = True,
203206
context: Optional[Dict[str, Any]] = None,
204-
) -> Optional[SuiteExecutionResult]:
207+
) -> SuiteExecutionResult:
205208
"""
206209
Run all expectations on the provided DataFrame with PySpark caching optimization.
207210
@@ -458,11 +461,11 @@ class DataFrameExpectationsSuite:
458461
runner_all.run(df) # Runs all 3 expectations
459462
460463
# Build runner for high OR medium priority expectations (OR logic)
461-
runner_any = suite.build(tags=["priority:high", "priority:medium"], tag_match_mode="any")
464+
runner_any = suite.build(tags=["priority:high", "priority:medium"], tag_match_mode=TagMatchMode.ANY)
462465
runner_any.run(df) # Runs 2 expectations (age and salary checks)
463466
464467
# Build runner for expectations with both high priority AND compliance category (AND logic)
465-
runner_and = suite.build(tags=["priority:high", "category:compliance"], tag_match_mode="all")
468+
runner_and = suite.build(tags=["priority:high", "category:compliance"], tag_match_mode=TagMatchMode.ALL)
466469
runner_and.run(df) # Runs 1 expectation (age check - has both tags)
467470
"""
468471

@@ -530,7 +533,7 @@ def dynamic_method(tags: Optional[List[str]] = None, **kwargs):
530533
def build(
531534
self,
532535
tags: Optional[List[str]] = None,
533-
tag_match_mode: Optional[Literal["any", "all"]] = None,
536+
tag_match_mode: Optional[TagMatchMode] = None,
534537
) -> DataFrameExpectationsSuiteRunner:
535538
"""
536539
Build an immutable runner from the current expectations.
@@ -542,10 +545,10 @@ def build(
542545
:param tags: Optional tag filters as list of strings in "key:value" format.
543546
Example: ["priority:high", "priority:medium"]
544547
If None or empty, all expectations will be included.
545-
:param tag_match_mode: How to match tags - "any" (OR logic) or "all" (AND logic).
548+
:param tag_match_mode: How to match tags - TagMatchMode.ANY (OR logic) or TagMatchMode.ALL (AND logic).
546549
Required if tags are provided, must be None if tags are not provided.
547-
- "any": Include expectations with ANY of the filter tags
548-
- "all": Include expectations with ALL of the filter tags
550+
- TagMatchMode.ANY: Include expectations with ANY of the filter tags
551+
- TagMatchMode.ALL: Include expectations with ALL of the filter tags
549552
:return: An immutable DataFrameExpectationsSuiteRunner instance.
550553
:raises ValueError: If no expectations have been added, if tag_match_mode validation fails,
551554
or if no expectations match the tag filters.

0 commit comments

Comments
 (0)