Commit 62129a9

Merge pull request #332 from tylerriccio33/agg-factory
Validation aggregation factory
2 parents a06985f + cbd33e7 commit 62129a9

11 files changed, +2258 -21 lines changed

CONTRIBUTING.md

Lines changed: 22 additions & 0 deletions
@@ -45,6 +45,28 @@ The tests are located in the `tests` folder and we use `pytest` for running them
 
 If you create new tests involving snapshots, please ensure that the resulting snapshots are relatively small. After adding snapshots, use `make test-update` (this runs `pytest --snapshot-update`). A subsequent use of `make test` should pass without any issues.
 
+### Creating Aggregation Methods
+
+Aggregation methods are generated dynamically because they all share the same signature and are registered on the `Validate` class in the same way. To add a new method, go to `pointblank/_agg.py` and add either a comparison or a statistical aggregation function.
+
+Comparison functions are named `comp_*`, for example `comp_gt` for "greater than". Statistical functions are named `agg_*`, for example `agg_sum` for "sum". At import time, these functions are registered and a grid of all combinations is created:
+```{python}
+Aggregator = Callable[[nw.DataFrame], float | int]
+Comparator = Callable[[Any, Any, Any], bool]
+
+AGGREGATOR_REGISTRY: dict[str, Aggregator] = {}
+
+COMPARATOR_REGISTRY: dict[str, Comparator] = {}
+```
+
+Once you've added one or more new functions, run `make pyi` to regenerate the type stubs in `pointblank/validate.pyi`, which contain the new signatures for the aggregation methods. At import time, the methods are added to the `Validate` class and resolved internally through the registry.
+```{python}
+# pointblank/validate.py
+for method in load_validation_method_grid():  # -> `col_sum_*`, `col_avg_*`, etc.
+    setattr(Validate, method, make_agg_validator(method))
+```
+
+At this point, the methods exist and their docs and signatures are visible to type checkers and IDEs/LSPs, which is very important for usability.
 ### Linting and Type Checking
 
 We use `ruff` for linting; the settings used are fairly loose and objective. Linting is run in pre-commit in CI. You can run it locally with `make lint`. Type checking is currently not enforced, but we intend to gradually type the codebase. You can run `make type` to run Astral's new experimental type checker `ty`. Feel free to use type hints and occasional type checking, but neither is obligatory at this time.
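The registration flow described in the CONTRIBUTING.md section above can be sketched end to end as follows. This is a simplified illustration rather than pointblank's implementation: aggregators take a plain list of numbers instead of a narwhals DataFrame, and `agg_max`, `ToyValidate`, `method_grid`, and `make_toy_validator` are hypothetical names introduced only for this example. How the real `make_agg_validator` derives its bounds is not shown in this diff, so the sketch assumes a single value is passed as both `lower` and `upper`.

```python
from __future__ import annotations

import itertools
from collections.abc import Callable, Sequence
from typing import Any

# Simplified stand-ins: the real aggregators receive a narwhals DataFrame,
# here they receive a plain sequence of numbers.
Aggregator = Callable[[Sequence[float]], float]
Comparator = Callable[[Any, Any, Any], bool]

AGGREGATOR_REGISTRY: dict[str, Aggregator] = {}
COMPARATOR_REGISTRY: dict[str, Comparator] = {}


def register(fn):
    """Store fn under its name minus the `agg_` or `comp_` prefix."""
    name = fn.__name__
    if name.startswith("comp_"):
        COMPARATOR_REGISTRY[name.removeprefix("comp_")] = fn
    elif name.startswith("agg_"):
        AGGREGATOR_REGISTRY[name.removeprefix("agg_")] = fn
    else:
        raise NotImplementedError(f"unrecognized prefix: {name}")
    return fn


@register
def agg_sum(values):
    return sum(values)


@register
def agg_max(values):  # hypothetical new aggregator, not part of the commit
    return max(values)


@register
def comp_gt(real, lower, upper):
    return bool(real > lower)


@register
def comp_lt(real, lower, upper):
    return bool(real < upper)


def method_grid() -> tuple[str, ...]:
    """Every aggregator/comparator pairing as a `col_<agg>_<comp>` name."""
    return tuple(
        f"col_{agg}_{comp}"
        for agg, comp in itertools.product(AGGREGATOR_REGISTRY, COMPARATOR_REGISTRY)
    )


class ToyValidate:  # hypothetical stand-in for `Validate`
    def __init__(self, values):
        self.values = list(values)


def make_toy_validator(method: str):
    """Hypothetical factory: bind one grid name to its registered functions."""
    agg_name, comp_name = method.removeprefix("col_").split("_")
    aggregator = AGGREGATOR_REGISTRY[agg_name]
    comparator = COMPARATOR_REGISTRY[comp_name]

    def validator(self, value: float) -> bool:
        # Assumption: a single bound is passed as both lower and upper.
        return comparator(aggregator(self.values), value, value)

    validator.__name__ = method
    return validator


# Attach one method per grid entry, mirroring the `setattr` loop in the docs.
for method in method_grid():
    setattr(ToyValidate, method, make_toy_validator(method))
```

With this in place, adding the single function `agg_max` is enough for `ToyValidate([1, 2, 3]).col_max_gt(2)` and `col_max_lt` to exist, which is the point of generating the methods from the two registries instead of writing each one by hand.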

Makefile

Lines changed: 7 additions & 0 deletions
@@ -1,5 +1,12 @@
 .PHONY: check
 
+.PHONY: pyi
+pyi: ## Generate .pyi stub files
+	@uv run stubgen ./pointblank/validate.py \
+		--include-private \
+		-o .
+	@uv run scripts/generate_agg_validate_pyi.py
+
 .PHONY: test
 test:
 	@uv run pytest tests \
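Methods attached with `setattr` at import time are invisible to static type checkers, which is why the `pyi` target regenerates stubs. The script `scripts/generate_agg_validate_pyi.py` is not shown in this diff, so the following is only a guess at the general idea, not its implementation: it renders one placeholder stub line per generated method name. The function `stub_lines` and the `*args`/`**kwargs` signature are assumptions made for illustration.

```python
from __future__ import annotations

import itertools

# Registry keys taken from `pointblank/_agg.py` in this commit.
AGG_NAMES = ("sum", "avg", "sd")
COMP_NAMES = ("eq", "gt", "ge", "lt", "le")


def stub_lines(indent: str = "    ") -> list[str]:
    """Render one placeholder stub per generated `col_<agg>_<comp>` method."""
    lines = []
    for agg, comp in itertools.product(AGG_NAMES, COMP_NAMES):
        # Assumed signature: the real parameters are not visible in this diff.
        lines.append(
            f"{indent}def col_{agg}_{comp}(self, *args: Any, **kwargs: Any) -> Validate: ..."
        )
    return lines


if __name__ == "__main__":
    print("\n".join(stub_lines()))
```

Three aggregators and five comparators give fifteen stub lines, one for each method in the grid.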

pointblank/_agg.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import itertools
+from collections.abc import Callable
+from typing import Any
+
+import narwhals as nw
+
+# TODO: Should take any frame type
+Aggregator = Callable[[nw.DataFrame], float | int]
+Comparator = Callable[[Any, Any, Any], bool]
+
+AGGREGATOR_REGISTRY: dict[str, Aggregator] = {}
+
+COMPARATOR_REGISTRY: dict[str, Comparator] = {}
+
+
+def register(fn):
+    """Register an aggregator or comparator function."""
+    name: str = fn.__name__
+    if name.startswith("comp_"):
+        COMPARATOR_REGISTRY[name.removeprefix("comp_")] = fn
+    elif name.startswith("agg_"):
+        AGGREGATOR_REGISTRY[name.removeprefix("agg_")] = fn
+    else:
+        raise NotImplementedError  # pragma: no cover
+    return fn
+
+
+## Aggregator Functions
+@register
+def agg_sum(column: nw.DataFrame) -> float:
+    return column.select(nw.all().sum()).item()
+
+
+@register
+def agg_avg(column: nw.DataFrame) -> float:
+    return column.select(nw.all().mean()).item()
+
+
+@register
+def agg_sd(column: nw.DataFrame) -> float:
+    return column.select(nw.all().std()).item()
+
+
+## Comparator functions:
+@register
+def comp_eq(real: float, lower: float, upper: float) -> bool:
+    if lower == upper:
+        return bool(real == lower)
+    return _generic_between(real, lower, upper)
+
+
+@register
+def comp_gt(real: float, lower: float, upper: float) -> bool:
+    return bool(real > lower)
+
+
+@register
+def comp_ge(real: Any, lower: float, upper: float) -> bool:
+    return bool(real >= lower)
+
+
+@register
+def comp_lt(real: float, lower: float, upper: float) -> bool:
+    return bool(real < upper)
+
+
+@register
+def comp_le(real: float, lower: float, upper: float) -> bool:
+    return bool(real <= upper)
+
+
+def _generic_between(real: Any, lower: Any, upper: Any) -> bool:
+    """Call if comparator needs to check between two values."""
+    return bool(lower <= real <= upper)
+
+
+def resolve_agg_registries(name: str) -> tuple[Aggregator, Comparator]:
+    """Resolve the assertion name to a valid aggregator
+
+    Args:
+        name (str): The name of the assertion.
+
+    Returns:
+        tuple[Aggregator, Comparator]: The aggregator and comparator functions.
+    """
+    name = name.removeprefix("col_")
+    agg_name, comp_name = name.split("_")[-2:]
+
+    aggregator = AGGREGATOR_REGISTRY.get(agg_name)
+    comparator = COMPARATOR_REGISTRY.get(comp_name)
+
+    if aggregator is None:  # pragma: no cover
+        raise ValueError(f"Aggregator '{agg_name}' not found in registry.")
+
+    if comparator is None:  # pragma: no cover
+        raise ValueError(f"Comparator '{comp_name}' not found in registry.")
+
+    return aggregator, comparator
+
+
+def is_valid_agg(name: str) -> bool:
+    try:
+        resolve_agg_registries(name)
+        return True
+    except ValueError:
+        return False
+
+
+def load_validation_method_grid() -> tuple[str, ...]:
+    """Generate all possible validation methods."""
+    methods = []
+    for agg_name, comp_name in itertools.product(
+        AGGREGATOR_REGISTRY.keys(), COMPARATOR_REGISTRY.keys()
+    ):
+        method = f"col_{agg_name}_{comp_name}"
+        methods.append(method)
+
+    return tuple(methods)
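Two details of `pointblank/_agg.py` are easier to see in isolation: how a method name is split back into registry keys, and how the three-argument comparators use their bounds. The sketch below copies that logic into standalone functions with no narwhals dependency. `split_method_name` and `is_known` are hypothetical helper names, and the `n_unique` key is a made-up example used only to show a consequence of keeping the last two underscore-separated parts.

```python
from __future__ import annotations

# Registry keys defined in this commit.
AGGREGATORS = {"sum", "avg", "sd"}
COMPARATORS = {"eq", "gt", "ge", "lt", "le"}


def split_method_name(name: str) -> tuple[str, str]:
    """Mirror the parsing step: drop `col_`, keep the last two `_` parts."""
    parts = name.removeprefix("col_").split("_")[-2:]
    if len(parts) != 2:
        raise ValueError(f"cannot parse {name!r}")
    agg_name, comp_name = parts
    return agg_name, comp_name


def is_known(name: str) -> bool:
    """True when both parsed keys exist in the registries."""
    try:
        agg_name, comp_name = split_method_name(name)
    except ValueError:
        return False
    return agg_name in AGGREGATORS and comp_name in COMPARATORS


def comp_eq(real, lower, upper):
    """Exact match when the bounds coincide, otherwise an inclusive range check."""
    if lower == upper:
        return bool(real == lower)
    return bool(lower <= real <= upper)


def comp_gt(real, lower, upper):
    return bool(real > lower)  # `upper` is accepted but unused


def comp_lt(real, lower, upper):
    return bool(real < upper)  # `lower` is accepted but unused


if __name__ == "__main__":
    print(split_method_name("col_sum_gt"))
    print(split_method_name("col_n_unique_gt"))  # hypothetical multi-word key
    print(comp_eq(10.0, 9.5, 10.5), comp_eq(11, 9.5, 10.5))
```

Because only the last two parts are kept, a hypothetical aggregator registered as `agg_n_unique` would be looked up as `unique` and fail to resolve, so new function names added under this scheme should use a single word after the `agg_` or `comp_` prefix.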
