Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
89d94a7
add ty
pavelzw Dec 17, 2025
a54715a
introduce ty
pavelzw Dec 17, 2025
ea0ff71
WIP
pavelzw Dec 17, 2025
275f1c3
Address some ty complaints.
kklein Dec 17, 2025
867a959
Merge branch 'main' of github.com:Quantco/datajudge into ty
kklein Dec 17, 2025
816c849
Don't expect children to implement retrieve and compare.
kklein Dec 17, 2025
07e031d
use --no-progress
pavelzw Dec 17, 2025
79ea6a8
Address some more ty complaints.
kklein Dec 18, 2025
2596c26
Fix another ty complaint.
kklein Dec 18, 2025
48b8c86
Add type annotation.
kklein Dec 18, 2025
041ce7f
Remove type igores.
kklein Dec 18, 2025
b7d158c
Bring back instance caching
kklein Dec 19, 2025
99222d3
Tell ty that the first argument to a method doesn't correspond to self.
kklein Dec 19, 2025
2ccfe58
Remove some mypy hacks
kklein Dec 20, 2025
226cef5
Update lock
kklein Jan 6, 2026
d700e93
Merge branch 'main' of github.com:Quantco/datajudge into ty
kklein Jan 14, 2026
979f81b
Merge branch 'main' of github.com:Quantco/datajudge into ty
kklein Jan 19, 2026
b174f39
Add annotation and fix typo
kklein Jan 19, 2026
fd60cfb
Fix method references
kklein Jan 20, 2026
309ddfa
Fix imports
kklein Jan 20, 2026
b27450a
Rework interfaces of select methods
kklein Jan 20, 2026
a58ac3c
Ensure that method overwrites are legitimate overwrites
kklein Jan 20, 2026
3d2af01
Fix field references
kklein Jan 20, 2026
e043bec
Fix more method references.
kklein Jan 20, 2026
e1b733f
Bring back mypy
kklein Jan 20, 2026
a09f601
Use default environment for mypy
kklein Jan 20, 2026
3750c16
Use updated lock
kklein Jan 20, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ repos:
types_or: [python, pyi]
require_serial: true
# mypy
- id: mypy
name: mypy
entry: pixi run -e mypy mypy
- id: ty
name: ty
entry: pixi run -e default ty check --no-progress
language: system
types: [python]
require_serial: true
Expand Down
1,656 changes: 1,247 additions & 409 deletions pixi.lock

Large diffs are not rendered by default.

14 changes: 5 additions & 9 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,16 @@ sqlalchemy = "2.*"
[feature.test.dependencies]
pytest-cov = "*"
pytest-xdist = "*"
ty = "*"
types-colorama = "*"
pandas-stubs = "*"
types-jinja2 = "*"
pytest-html = "*"

[feature.test.target.unix.dependencies]
pytest-memray = "*"
memray = "*"

[feature.mypy.dependencies]
mypy = "*"
types-setuptools = "*"
types-colorama = "*"
pandas-stubs = "*"
types-jinja2 = "*"

[feature.lint.dependencies]
pre-commit = "*"
docformatter = "*"
Expand Down Expand Up @@ -153,5 +151,3 @@ bigquery-py310 = ["bigquery", "py310", "test"]
bigquery-sa1 = ["bigquery", "sa1", "test"]

lint = { features = ["lint"], no-default-feature = true }

mypy = ["mypy"]
13 changes: 4 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,7 @@ known-first-party = ["datajudge"]
quote-style = "double"
indent-style = "space"

[tool.mypy]
python_version = '3.10'
no_implicit_optional = true
allow_empty_bodies = true
check_untyped_defs = true

[[tool.mypy.overrides]]
module = ["scipy.*", "pytest_html"]
ignore_missing_imports = true
[tool.ty.terminal]
error-on-warning = true
[tool.ty.rules]
invalid-method-override = "ignore"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want this?

31 changes: 19 additions & 12 deletions src/datajudge/constraints/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

import abc
from collections.abc import Callable, Collection
from dataclasses import dataclass, field
Expand Down Expand Up @@ -146,17 +144,24 @@ def __init__(
not isinstance(output_processors, list)
):
output_processors = [output_processors]
self.output_processors = output_processors
self.output_processors: list[OutputProcessor] | None = output_processors

self.cache_size = cache_size
self._setup_caching()

def _setup_caching(self):
# this has an added benefit of allowing the class to be garbage collected
# We don't use cache or lru_cache decorators since those would lead
# to class-based, not instance-based caching.
#
# Using this approach has the added benefit of allowing the class to be garbage collected
# according to https://rednafi.com/python/lru_cache_on_methods/
# and https://docs.astral.sh/ruff/rules/cached-instance-method/
self.get_factual_value = lru_cache(self.cache_size)(self.get_factual_value) # type: ignore[method-assign]
self.get_target_value = lru_cache(self.cache_size)(self.get_target_value) # type: ignore[method-assign]
self.get_factual_value: Callable[[Constraint, sa.engine.Engine], Any] = (
lru_cache(self.cache_size)(self.get_factual_value)
)
self.get_target_value: Callable[[Constraint, sa.engine.Engine], Any] = (
lru_cache(self.cache_size)(self.get_target_value)
)

def _check_if_valid_between_or_within(
self,
Expand All @@ -176,13 +181,11 @@ def _check_if_valid_between_or_within(
f"{class_name}. Use exactly either of them."
)

# @lru_cache(maxsize=None), see _setup_caching()
def get_factual_value(self, engine: sa.engine.Engine) -> Any:
factual_value, factual_selections = self.retrieve(engine, self.ref)
self.factual_selections = factual_selections
return factual_value

# @lru_cache(maxsize=None), see _setup_caching()
def get_target_value(self, engine: sa.engine.Engine) -> Any:
if self.ref2 is None:
return self.ref_value
Expand Down Expand Up @@ -235,14 +238,18 @@ def retrieve(
self, engine: sa.engine.Engine, ref: DataReference
) -> tuple[Any, OptionalSelections]:
"""Retrieve the value of interest for a DataReference from database."""
pass
raise NotImplementedError()

def compare(self, value_factual: Any, value_target: Any) -> tuple[bool, str | None]:
pass
raise NotImplementedError()

def test(self, engine: sa.engine.Engine) -> TestResult:
value_factual = self.get_factual_value(engine)
value_target = self.get_target_value(engine)
# ty can't figure out that this is a method and that self is passed
# as the first argument.
value_factual = self.get_factual_value(engine=engine) # type: ignore[missing-argument]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find this to be quite dissatisfying :(

# ty can't figure out that this is a method and that self is passed
# as the first argument.
value_target = self.get_target_value(engine=engine) # type: ignore[missing-argument]
is_success, assertion_message = self.compare(value_factual, value_target)
if is_success:
return TestResult.success()
Expand Down
2 changes: 1 addition & 1 deletion src/datajudge/constraints/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def retrieve(
# side effects. This should be removed as soon as snowflake column capitalization
# is fixed by snowflake-sqlalchemy.
if is_snowflake(engine) and self.ref_value is not None:
self.ref_value = lowercase_column_names(self.ref_value) # type: ignore
self.ref_value = lowercase_column_names(self.ref_value)
return db_access.get_column_names(engine, ref)


Expand Down
4 changes: 3 additions & 1 deletion src/datajudge/constraints/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def retrieve(
value, selections = db_access.get_max(engine, ref)
return convert_to_date(value, self.format), selections

def compare(self, max_factual: dt.date, max_target: dt.date) -> tuple[bool, str]:
def compare(
self, max_factual: dt.date, max_target: dt.date
) -> tuple[bool, str | None]:
if max_factual is None:
return True, None
if max_target is None:
Expand Down
2 changes: 1 addition & 1 deletion src/datajudge/constraints/miscs.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test(self, engine: sa.engine.Engine) -> TestResult:
self.target_selections = unique_selections
if row_count == 0:
return TestResult(True, "No occurrences.")
tolerance_kind, tolerance_value = self.ref_value # type: ignore
tolerance_kind, tolerance_value = self.ref_value
if tolerance_kind == "relative":
result = unique_count >= row_count * (1 - tolerance_value)
elif tolerance_kind == "absolute":
Expand Down
8 changes: 6 additions & 2 deletions src/datajudge/constraints/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,12 @@ def retrieve(
return result, selections

def test(self, engine: sa.engine.Engine) -> TestResult:
mean_factual = self.get_factual_value(engine)
mean_target = self.get_target_value(engine)
# ty can't figure out that this is a method and that self is passed
# as the first argument.
mean_factual = self.get_factual_value(engine=engine) # type: ignore[missing-argument]
# ty can't figure out that this is a method and that self is passed
# as the first argument.
mean_target = self.get_target_value(engine=engine) # type: ignore[missing-argument]
if mean_factual is None or mean_target is None:
return TestResult(
mean_factual is None and mean_target is None,
Expand Down
6 changes: 1 addition & 5 deletions src/datajudge/constraints/varchar.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,7 @@ def test(self, engine: sa.engine.Engine) -> TestResult:
return TestResult.failure("No regex pattern given")

pattern = re.compile(self.ref_value)
uniques_mismatching = {
x
for x in uniques_factual
if not pattern.match(x) # type: ignore
}
uniques_mismatching = {x for x in uniques_factual if not pattern.match(x)}

if self.aggregated:
n_violations = len(uniques_mismatching)
Expand Down
36 changes: 19 additions & 17 deletions src/datajudge/db_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import sqlalchemy as sa
from sqlalchemy.sql import selectable
from sqlalchemy.sql.expression import FromClause


def is_mssql(engine: sa.engine.Engine) -> bool:
Expand Down Expand Up @@ -211,12 +210,10 @@ def get_comparison_string(self, table_variable1: str, table_variable2: str) -> s

class DataSource(ABC):
@abstractmethod
def __str__(self) -> str:
pass
def __str__(self) -> str: ...

@abstractmethod
def get_clause(self, engine: sa.engine.Engine) -> FromClause:
pass
def get_clause(self, engine: sa.engine.Engine) -> sa.FromClause: ...


@functools.lru_cache(maxsize=1)
Expand All @@ -241,10 +238,10 @@ def __str__(self) -> str:
return f"{self.db_name}.{self.schema_name}.{self.table_name}"
return self.table_name

def get_clause(self, engine: sa.engine.Engine) -> FromClause:
def get_clause(self, engine: sa.engine.Engine) -> sa.Table:
schema = self.schema_name
if is_mssql(engine):
schema = self.db_name + "." + self.schema_name # type: ignore
if is_mssql(engine) and self.schema_name:
schema = self.db_name + "." + self.schema_name

return sa.Table(
self.table_name,
Expand All @@ -256,7 +253,7 @@ def get_clause(self, engine: sa.engine.Engine) -> FromClause:

@final
class ExpressionDataSource(DataSource):
def __init__(self, expression: FromClause | sa.Select, name: str):
def __init__(self, expression: sa.FromClause | sa.Select, name: str):
self.expression = expression
self.name = name

Expand All @@ -266,7 +263,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"{self.__class__.__name__}(expression={self.expression!r}, name={self.name})"

def get_clause(self, engine: sa.engine.Engine) -> FromClause:
def get_clause(self, engine: sa.engine.Engine) -> sa.FromClause:
return self.expression.alias()


Expand Down Expand Up @@ -294,7 +291,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"{self.__class__.__name__}(query_string={self.query_string}, name={self.name}, columns={self.columns})"

def get_clause(self, engine: sa.engine.Engine) -> FromClause:
def get_clause(self, engine: sa.engine.Engine) -> sa.FromClause:
return self.clause


Expand Down Expand Up @@ -488,7 +485,7 @@ def get_interval_overlaps_nd(
start_columns: list[str],
end_columns: list[str],
end_included: bool,
) -> tuple[sa.sql.selectable.CompoundSelect, sa.sql.selectable.Select]:
) -> tuple[sa.CompoundSelect, sa.Select]:
"""Create selectables for interval overlaps in n dimensions.

We define the presence of 'overlap' as presence of a non-empty intersection
Expand Down Expand Up @@ -1160,10 +1157,15 @@ def get_column_type(engine: sa.engine.Engine, ref: DataReference) -> tuple[Any,
def get_primary_keys(
engine: sa.engine.Engine, ref: DataReference
) -> tuple[list[str], None]:
table = ref.data_source.get_clause(engine)
# Kevin, 25/02/04
# Mypy complains about the following for a reason I can't follow.
return [column.name for column in table.primary_key.columns], None # type: ignore
data_source = ref.data_source

if isinstance(data_source, TableDataSource):
table = data_source.get_clause(engine)
return [column.name for column in table.primary_key.columns], None

raise NotImplementedError(
f"Cannot determine primary keys of a data source of type {type(data_source)}."
)


def get_row_difference_sample(
Expand Down Expand Up @@ -1247,7 +1249,7 @@ def get_row_mismatch(
return result_mismatch, result_n_rows, [selection_difference, selection_n_rows]


def duplicates(subquery: sa.sql.selectable.Subquery) -> sa.Select:
def duplicates(subquery: sa.Subquery) -> sa.Select:
aggregate_subquery = (
sa.select(subquery, sa.func.count().label("n_copies"))
.select_from(subquery)
Expand Down
2 changes: 1 addition & 1 deletion src/datajudge/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self):
self.known_bb_pattern = re.compile(STYLING_CODES)

# Just ignore styling in the default formatter
def apply_formatting(self, _: str, inner: str) -> str:
def apply_formatting(self, code: str, inner: str) -> str:
return inner

def fmt_str(self, string: str) -> str:
Expand Down
17 changes: 5 additions & 12 deletions src/datajudge/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,20 +1381,13 @@ def get_deviation_getter(
) -> Callable[[sa.engine.Engine], float]:
if fix_value is None and deviation is None:
raise ValueError("No valid gain/loss/deviation given.")
# The second predictate is redundant but appeases mypy since fix_value
# could, a priori, be None.
if deviation is None and fix_value is not None:
if deviation is None:
return lambda engine: fix_value
# The second predictate is redundant but appeases mypy since deviation
# could, a priori, be None.
if fix_value is None and deviation is not None:
if fix_value is None:
return lambda engine: self.get_date_growth_rate(engine) + deviation
# This clause is redundant but appeases mypy.
if fix_value is not None and deviation is not None:
return lambda engine: max(
fix_value, self.get_date_growth_rate(engine) + deviation
)
raise ValueError("No valid gain/loss/deviation given.")
return lambda engine: max(
fix_value, self.get_date_growth_rate(engine) + deviation
)

def add_n_rows_equality_constraint(
self,
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def get_engine(backend) -> sa.engine.Engine:
elif "snowflake" in backend:
# cryptography is a dependency of snowflake-connector,
# which is not present in the default environment
from cryptography.hazmat.primitives import serialization # type: ignore
from cryptography.hazmat.primitives import ( # type: ignore[unresolved-import]
serialization,
)

if not (private_key_env := os.getenv("SNOWFLAKE_PRIVATE_KEY")):
raise ValueError("SNOWFLAKE_PRIVATE_KEY environment variable is not set")
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ def test_n_rows_max_gain_between(engine, mix_table1, mix_table2, data):
condition1,
condition2,
) = data
req = requirements.BetweenRequirement.from_tables( # type: ignore[misc]
req = requirements.BetweenRequirement.from_tables(
*mix_table1,
*mix_table2,
date_column="col_date",
date_column2="col_date",
date_column="col_date", # type: ignore[parameter-already-assigned]
date_column2="col_date", # type: ignore[parameter-already-assigned]
)
req.add_n_rows_max_gain_constraint(
constant_max_relative_gain=constant_max_relative_gain,
Expand Down
11 changes: 4 additions & 7 deletions tests/integration/test_stats.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest

import datajudge
from datajudge.constraints.stats import KolmogorovSmirnov2Sample
from datajudge.db_access import (
DataReference,
TableDataSource,
_cross_cdf_selection,
is_bigquery,
is_db2,
)
Expand All @@ -19,9 +20,7 @@ def test_cross_cdf_selection(engine, cross_cdf_table1, cross_cdf_table2):
tds2 = TableDataSource(database2, table2, schema2)
ref1 = DataReference(tds1, columns=["col_int"])
ref2 = DataReference(tds2, columns=["col_int"])
selection, _, _ = datajudge.db_access._cross_cdf_selection(
engine, ref1, ref2, "cdf", "value"
)
selection, _, _ = _cross_cdf_selection(engine, ref1, ref2, "cdf", "value")
with engine.connect() as connection:
result = connection.execute(selection).fetchall()
assert result is not None and len(result) > 0
Expand Down Expand Up @@ -61,9 +60,7 @@ def test_ks_2sample_calculate_statistic(engine, random_normal_table, configurati
n_samples,
m_samples,
_,
) = datajudge.constraints.stats.KolmogorovSmirnov2Sample.calculate_statistic(
engine, ref, ref2
)
) = KolmogorovSmirnov2Sample.calculate_statistic(engine, ref, ref2)

assert abs(d_statistic - expected_d) <= 1e-10, (
f"The test statistic does not match: {expected_d} vs {d_statistic}"
Expand Down
Loading