Skip to content

Commit c182858

Browse files
committed
feat: simplified registry
1 parent 9a76467 commit c182858

File tree

13 files changed

+142
-85
lines changed

13 files changed

+142
-85
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ uv run pytest tests/ --cov=dataframe_expectations
5555

5656
**Basic usage with Pandas:**
5757
```python
58-
from dataframe_expectations.expectations_suite import DataFrameExpectationsSuite
58+
from dataframe_expectations.suite import DataFrameExpectationsSuite
5959
import pandas as pd
6060

6161
# Build a suite with expectations
@@ -82,7 +82,7 @@ runner.run(df)
8282

8383
**PySpark example:**
8484
```python
85-
from dataframe_expectations.expectations_suite import DataFrameExpectationsSuite
85+
from dataframe_expectations.suite import DataFrameExpectationsSuite
8686
from pyspark.sql import SparkSession
8787

8888
# Initialize Spark session
@@ -116,7 +116,7 @@ runner.run(df)
116116

117117
**Decorator pattern for automatic validation:**
118118
```python
119-
from dataframe_expectations.expectations_suite import DataFrameExpectationsSuite
119+
from dataframe_expectations.suite import DataFrameExpectationsSuite
120120
from pyspark.sql import SparkSession
121121

122122
# Initialize Spark session

dataframe_expectations/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""DataFrame Expectations - A validation library for pandas and PySpark DataFrames."""
22

3-
__version__ = "0.3.0"
3+
from importlib.metadata import version
4+
5+
__version__ = version("dataframe-expectations")
46

57
__all__ = []

dataframe_expectations/expectations/aggregation/unique.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
from pandas import DataFrame as PandasDataFrame
55
from pyspark.sql import DataFrame as PySparkDataFrame
66
from pyspark.sql import functions as F
7-
8-
from dataframe_expectations.core.types import DataFrameLike, DataFrameType
97
from dataframe_expectations.core.aggregation_expectation import (
108
DataFrameAggregationExpectation,
119
)
1210
from dataframe_expectations.core.types import (
1311
ExpectationCategory,
1412
ExpectationSubcategory,
13+
DataFrameLike,
14+
DataFrameType,
1515
)
1616
from dataframe_expectations.registry import register_expectation
1717
from dataframe_expectations.core.utils import requires_params

dataframe_expectations/registry.py

Lines changed: 99 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from typing import Any, Callable, Dict, Optional
2+
from typing import Any, Callable, Dict, Optional, Tuple
33

44
from dataframe_expectations.core.expectation import DataFrameExpectation
55
from dataframe_expectations.core.types import (
@@ -11,12 +11,20 @@
1111

1212
logger = setup_logger(__name__)
1313

14+
# Type alias for registry entry (factory function + metadata)
15+
FactoryFunction = Callable[..., DataFrameExpectation]
16+
RegistryEntry = Tuple[FactoryFunction, ExpectationMetadata]
17+
1418

1519
class DataFrameExpectationRegistry:
1620
"""Registry for dataframe expectations."""
1721

18-
_expectations: Dict[str, Callable[..., DataFrameExpectation]] = {}
19-
_metadata: Dict[str, ExpectationMetadata] = {}
22+
# Primary registry: keyed by suite_method_name for O(1) suite access
23+
_registry: Dict[str, RegistryEntry] = {}
24+
25+
# Secondary index: maps expectation_name -> suite_method_name for O(1) lookups
26+
_by_name: Dict[str, str] = {}
27+
2028
_loaded: bool = False
2129

2230
@classmethod
@@ -41,21 +49,31 @@ def register(
4149
:return: Decorator function.
4250
"""
4351

44-
def decorator(func: Callable[..., DataFrameExpectation]):
52+
def decorator(func: FactoryFunction) -> FactoryFunction:
4553
expectation_name = name
4654

4755
logger.debug(
4856
f"Registering expectation '{expectation_name}' with function {func.__name__}"
4957
)
5058

51-
# Check if the name is already registered
52-
if expectation_name in cls._expectations:
53-
error_message = f"Expectation '{expectation_name}' is already registered."
59+
suite_method = suite_method_name or cls._convert_to_suite_method(expectation_name)
60+
61+
# Check for duplicate suite method name
62+
if suite_method in cls._registry:
63+
existing_metadata = cls._registry[suite_method][1]
64+
error_message = (
65+
f"Suite method '{suite_method}' is already registered by expectation '{existing_metadata.expectation_name}'. "
66+
f"Cannot register '{expectation_name}'."
67+
)
5468
logger.error(error_message)
5569
raise ValueError(error_message)
5670

57-
# Register factory function
58-
cls._expectations[expectation_name] = func
71+
# Check for duplicate expectation name
72+
if expectation_name in cls._by_name:
73+
existing_suite_method = cls._by_name[expectation_name]
74+
error_message = f"Expectation '{expectation_name}' is already registered with suite method '{existing_suite_method}'."
75+
logger.error(error_message)
76+
raise ValueError(error_message)
5977

6078
# Extract params from @requires_params if present
6179
extracted_params = []
@@ -64,10 +82,8 @@ def decorator(func: Callable[..., DataFrameExpectation]):
6482
extracted_params = list(func._required_params)
6583
extracted_types = getattr(func, "_param_types", {})
6684

67-
# Store metadata
68-
cls._metadata[expectation_name] = ExpectationMetadata(
69-
suite_method_name=suite_method_name
70-
or cls._convert_to_suite_method(expectation_name),
85+
metadata = ExpectationMetadata(
86+
suite_method_name=suite_method,
7187
pydoc=pydoc,
7288
category=category,
7389
subcategory=subcategory,
@@ -78,6 +94,12 @@ def decorator(func: Callable[..., DataFrameExpectation]):
7894
expectation_name=expectation_name,
7995
)
8096

97+
# Store in primary registry
98+
cls._registry[suite_method] = (func, metadata)
99+
100+
# Store in secondary index
101+
cls._by_name[expectation_name] = suite_method
102+
81103
return func
82104

83105
return decorator
@@ -93,8 +115,9 @@ def _convert_to_suite_method(cls, expectation_name: str) -> str:
93115
ExpectationValueGreaterThan -> expect_value_greater_than
94116
ExpectationMinRows -> expect_min_rows
95117
"""
96-
# Remove 'Expectation' prefix
118+
97119
name = re.sub(r"^Expectation", "", expectation_name)
120+
98121
# Convert CamelCase to snake_case
99122
snake = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", name)
100123
snake = re.sub(r"([a-z\d])([A-Z])", r"\1_\2", snake)
@@ -141,21 +164,28 @@ def _load_all_expectations(cls):
141164
def get_expectation(cls, expectation_name: str, **kwargs) -> DataFrameExpectation:
142165
"""Get an expectation instance by name.
143166
167+
Note: This method is kept for backward compatibility with tests.
168+
The suite uses get_expectation_by_suite_method() for better performance.
169+
144170
:param expectation_name: The name of the expectation.
145171
:param kwargs: Parameters to pass to the expectation factory function.
146172
:return: An instance of DataFrameExpectation.
147173
"""
148-
cls._ensure_loaded() # Lazy load expectations
174+
cls._ensure_loaded()
149175
logger.debug(f"Retrieving expectation '{expectation_name}' with arguments: {kwargs}")
150-
if expectation_name not in cls._expectations:
176+
177+
if expectation_name not in cls._by_name:
151178
available = cls.list_expectations()
152179
error_message = (
153180
f"Unknown expectation '{expectation_name}'. "
154181
f"Available expectations: {', '.join(available)}"
155182
)
156183
logger.error(error_message)
157184
raise ValueError(error_message)
158-
return cls._expectations[expectation_name](**kwargs)
185+
186+
suite_method = cls._by_name[expectation_name]
187+
factory, metadata = cls._registry[suite_method]
188+
return factory(**kwargs)
159189

160190
@classmethod
161191
def get_metadata(cls, expectation_name: str) -> ExpectationMetadata:
@@ -166,9 +196,13 @@ def get_metadata(cls, expectation_name: str) -> ExpectationMetadata:
166196
:raises ValueError: If expectation not found.
167197
"""
168198
cls._ensure_loaded()
169-
if expectation_name not in cls._metadata:
199+
200+
if expectation_name not in cls._by_name:
170201
raise ValueError(f"No metadata found for expectation '{expectation_name}'")
171-
return cls._metadata[expectation_name]
202+
203+
suite_method = cls._by_name[expectation_name]
204+
factory, metadata = cls._registry[suite_method]
205+
return metadata
172206

173207
@classmethod
174208
def get_all_metadata(cls) -> Dict[str, ExpectationMetadata]:
@@ -177,7 +211,35 @@ def get_all_metadata(cls) -> Dict[str, ExpectationMetadata]:
177211
:return: Dictionary mapping expectation names to their metadata.
178212
"""
179213
cls._ensure_loaded()
180-
return cls._metadata.copy()
214+
return {metadata.expectation_name: metadata for _, (_, metadata) in cls._registry.items()}
215+
216+
@classmethod
217+
def get_expectation_by_suite_method(
218+
cls, suite_method_name: str, **kwargs
219+
) -> DataFrameExpectation:
220+
"""Get an expectation instance by suite method name.
221+
222+
:param suite_method_name: The suite method name (e.g., 'expect_value_greater_than').
223+
:param kwargs: Parameters to pass to the expectation factory function.
224+
:return: An instance of DataFrameExpectation.
225+
:raises ValueError: If suite method not found.
226+
"""
227+
cls._ensure_loaded()
228+
logger.debug(
229+
f"Retrieving expectation for suite method '{suite_method_name}' with arguments: {kwargs}"
230+
)
231+
232+
if suite_method_name not in cls._registry:
233+
available = list(cls._registry.keys())
234+
error_message = (
235+
f"Unknown suite method '{suite_method_name}'. "
236+
f"Available methods: {', '.join(available[:10])}..."
237+
)
238+
logger.error(error_message)
239+
raise ValueError(error_message)
240+
241+
factory, metadata = cls._registry[suite_method_name]
242+
return factory(**kwargs)
181243

182244
@classmethod
183245
def get_suite_method_mapping(cls) -> Dict[str, str]:
@@ -187,16 +249,19 @@ def get_suite_method_mapping(cls) -> Dict[str, str]:
187249
to expectation names (e.g., 'ExpectationValueGreaterThan').
188250
"""
189251
cls._ensure_loaded()
190-
return {meta.suite_method_name: exp_name for exp_name, meta in cls._metadata.items()}
252+
return {
253+
suite_method: metadata.expectation_name
254+
for suite_method, (_, metadata) in cls._registry.items()
255+
}
191256

192257
@classmethod
193258
def list_expectations(cls) -> list:
194259
"""List all registered expectation names.
195260
196261
:return: List of registered expectation names.
197262
"""
198-
cls._ensure_loaded() # Lazy load expectations
199-
return list(cls._expectations.keys())
263+
cls._ensure_loaded()
264+
return [metadata.expectation_name for _, (_, metadata) in cls._registry.items()]
200265

201266
@classmethod
202267
def remove_expectation(cls, expectation_name: str):
@@ -205,23 +270,25 @@ def remove_expectation(cls, expectation_name: str):
205270
:param expectation_name: The name of the expectation to remove.
206271
:raises ValueError: If expectation not found.
207272
"""
208-
cls._ensure_loaded() # Lazy load expectations
273+
cls._ensure_loaded()
209274
logger.debug(f"Removing expectation '{expectation_name}'")
210-
if expectation_name in cls._expectations:
211-
del cls._expectations[expectation_name]
212-
if expectation_name in cls._metadata:
213-
del cls._metadata[expectation_name]
214-
else:
275+
276+
if expectation_name not in cls._by_name:
215277
error_message = f"Expectation '{expectation_name}' not found."
216278
logger.error(error_message)
217279
raise ValueError(error_message)
218280

281+
# Remove from both dictionaries
282+
suite_method = cls._by_name[expectation_name]
283+
del cls._registry[suite_method]
284+
del cls._by_name[expectation_name]
285+
219286
@classmethod
220287
def clear_expectations(cls):
221288
"""Clear all registered expectations."""
222-
logger.debug(f"Clearing {len(cls._expectations)} expectations from the registry")
223-
cls._expectations.clear()
224-
cls._metadata.clear()
289+
logger.debug(f"Clearing {len(cls._registry)} expectations from the registry")
290+
cls._registry.clear()
291+
cls._by_name.clear()
225292
cls._loaded = False # Allow reloading
226293

227294

dataframe_expectations/suite.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ def run(
108108
pyspark_df = cast(PySparkDataFrame, data_frame)
109109
was_already_cached = pyspark_df.is_cached
110110

111-
# Cache the DataFrame if it wasn't already cached
112111
if not was_already_cached:
113112
logger.debug("Caching PySpark DataFrame for expectations suite execution")
114113
pyspark_df.cache()
@@ -206,7 +205,6 @@ def wrapper(*args, **kwargs):
206205
logger.info(f"Validating DataFrame returned from '{f.__name__}'")
207206
self.run(data_frame=result)
208207

209-
# Return the original DataFrame if validation passes
210208
return result
211209

212210
return wrapper
@@ -254,42 +252,32 @@ def __getattr__(self, name: str):
254252
if not name.startswith("expect_"):
255253
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
256254

257-
mapping = DataFrameExpectationRegistry.get_suite_method_mapping()
255+
# Create and return the dynamic method - the name is validated lazily, when the returned method is called
256+
return self._create_expectation_method(name)
258257

259-
# Check if this method exists in the registry
260-
if name not in mapping:
261-
available = list(mapping.keys())
262-
raise AttributeError(
263-
f"Unknown expectation method '{name}'. "
264-
f"Available methods: {', '.join(available[:5])}..."
265-
)
266-
267-
expectation_name = mapping[name]
268-
269-
# Create and return the dynamic method
270-
return self._create_expectation_method(expectation_name, name)
271-
272-
def _create_expectation_method(self, expectation_name: str, method_name: str):
258+
def _create_expectation_method(self, suite_method_name: str):
273259
"""
274260
Create a dynamic expectation method.
275261
276-
Returns a closure that captures the expectation_name and self.
262+
Returns a closure that captures the suite_method_name and self.
277263
"""
278264

279265
def dynamic_method(**kwargs):
280266
"""Dynamically generated expectation method."""
281-
expectation = DataFrameExpectationRegistry.get_expectation(
282-
expectation_name=expectation_name, **kwargs
283-
)
267+
try:
268+
expectation = DataFrameExpectationRegistry.get_expectation_by_suite_method(
269+
suite_method_name=suite_method_name, **kwargs
270+
)
271+
except ValueError as e:
272+
raise AttributeError(str(e)) from e
284273

285274
logger.info(f"Adding expectation: {expectation}")
286275

287-
# Add to internal list
288276
self.__expectations.append(expectation)
289277
return self
290278

291279
# Set helpful name for debugging
292-
dynamic_method.__name__ = method_name
280+
dynamic_method.__name__ = suite_method_name
293281

294282
return dynamic_method
295283

File renamed without changes.

0 commit comments

Comments
 (0)