Skip to content

Commit ae6c893

Browse files
Benjamin Gutzmann (gutzbenj)
authored and committed
Enhance query handling for nested columns and date filters
- Improve nested/dotted column handling: alias nested columns with underscores for consistent querying when data is loaded via database accessors. - Consolidate date-type filters to be ORed together when fetching data into memory, allowing multiple date conditions to apply correctly. - Add validation to ensure filters with values specify corresponding columns.
1 parent d383522 commit ae6c893

File tree

7 files changed

+618
-42
lines changed

7 files changed

+618
-42
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@ Types of changes:
1616

1717
## [Unreleased]
1818

19+
### Fixed
20+
21+
- Improve nested/dotted column handling: when data is loaded from external sources via database accessors, nested
22+
columns (e.g., `value.shopId`) are now aliased with underscores (`value_shopId`) for consistent querying. Native
23+
DuckDB struct columns continue to use dot notation. This ensures proper handling in WHERE clauses, JOIN conditions,
24+
and SELECT statements across all check types.
25+
- Consolidate date-type filters to be ORed together (instead of ANDed) when fetching data into memory, allowing multiple
26+
date conditions to apply correctly.
27+
1928
## [0.11.2] - 2026-01-23
2029

2130
### Fixed

src/koality/checks.py

Lines changed: 78 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,20 @@ def __init__(
114114
def in_memory_column(self) -> str:
115115
"""Return the column name to reference in in-memory queries.
116116
117-
If a configured column references a nested field (e.g. "value.shopId"),
118-
the in-memory representation uses the last segment ("shopId"). This
119-
property provides that flattened name without modifying the original
117+
If a configured column references a nested field (e.g. "value.shopId"):
118+
- When querying data loaded via database_accessor: uses underscores ("value_shopId")
119+
because the executor flattens struct columns with underscore aliases
120+
- When querying existing DuckDB tables (no accessor): keeps dots ("value.shopId")
121+
to support native DuckDB struct column syntax
122+
123+
This property provides the appropriate name without modifying the original
120124
configured `self.check_column` which is still used for result writing.
121125
"""
122-
if isinstance(self.check_column, str) and "." in self.check_column:
123-
return self.check_column.split(".")[-1]
126+
if isinstance(self.check_column, str) and "." in self.check_column: # noqa: SIM102
127+
# Only convert to underscores if data was loaded via database_accessor
128+
# (which flattens structs). For native DuckDB tables, keep dotted notation.
129+
if self.database_accessor:
130+
return self.check_column.replace(".", "_")
124131
return self.check_column
125132

126133
@property
@@ -391,15 +398,25 @@ def get_identifier_filter(filters: dict[str, dict[str, Any]]) -> tuple[str, dict
391398
return None
392399

393400
@staticmethod
394-
def assemble_where_statement(filters: dict[str, dict[str, Any]], *, strip_dotted_columns: bool = True) -> str:
401+
def assemble_where_statement( # noqa: C901
402+
filters: dict[str, dict[str, Any]],
403+
*,
404+
strip_dotted_columns: bool = True,
405+
database_accessor: str | None = None,
406+
) -> str:
395407
"""Generate the where statement for the check query using the specified filters.
396408
397409
Args:
398410
filters: A dict containing filter specifications, e.g.,
399411
strip_dotted_columns: When True (default), dotted column names (e.g. "a.b") are
400-
reduced to their last component ("b") for WHERE clauses. If False, the full
401-
dotted expression is preserved. This is useful when querying source databases
402-
that expect the original dotted column syntax.
412+
transformed based on the data source:
413+
- With database_accessor: converted to underscores ("a_b") for flattened data
414+
- Without database_accessor: kept as dots ("a.b") for native DuckDB structs
415+
If False, the full dotted expression is preserved regardless (used when
416+
querying source databases that expect the original dotted column syntax).
417+
database_accessor: Optional database accessor string. When provided and non-empty,
418+
indicates data was loaded from external source and dotted columns should be
419+
converted to underscores.
403420
404421
Example filters:
405422
`{
@@ -445,17 +462,22 @@ def assemble_where_statement(filters: dict[str, dict[str, Any]], *, strip_dotted
445462
# If column is not provided we cannot build a WHERE condition
446463
if column is None:
447464
continue
448-
# If the column references a nested field (e.g. "value.shopId"),
449-
# databases that flatten JSON may have the column stored without the
450-
# prefix. By default we use the last component after the dot for the
451-
# WHERE clause, but callers can disable this behavior by setting
452-
# `strip_dotted_columns=False` (used when querying source DBs so the
453-
# original dotted expression is preserved).
454-
if isinstance(column, str) and "." in column and strip_dotted_columns:
455-
column = column.split(".")[-1]
465+
# If the column references a nested field (e.g. "value.shopId"):
466+
# - With database_accessor: convert to underscores (value_shopId) for flattened data
467+
# - Without database_accessor: keep dots (value.shopId) for native DuckDB structs
468+
# Callers can disable this by setting `strip_dotted_columns=False`
469+
# (used when querying source DBs where the original dotted expression is needed).
470+
if isinstance(column, str) and "." in column and strip_dotted_columns and database_accessor:
471+
column = column.replace(".", "_")
472+
# else: keep dotted notation for native DuckDB struct support
456473

457474
operator = filter_dict.get("operator", "=")
458475

476+
# Cast date columns for proper comparison
477+
is_date_filter = filter_dict.get("type") == "date"
478+
if is_date_filter:
479+
column = f"CAST({column} AS DATE)"
480+
459481
# Handle NULL values with IS NULL / IS NOT NULL
460482
if value is None:
461483
if operator == "!=":
@@ -465,6 +487,9 @@ def assemble_where_statement(filters: dict[str, dict[str, Any]], *, strip_dotted
465487
continue
466488

467489
formatted_value = format_filter_value(value, operator)
490+
# Prefix DATE for date type filters
491+
if is_date_filter and operator not in ("BETWEEN", "IN", "NOT IN"):
492+
formatted_value = f"DATE {formatted_value}"
468493
filters_statements.append(f" {column} {operator} {formatted_value}")
469494

470495
if len(filters_statements) == 0:
@@ -547,7 +572,7 @@ def assemble_query(self) -> str:
547572
if isinstance(self, IqrOutlierCheck):
548573
filters = {name: cfg for name, cfg in filters.items() if cfg.get("type") != "date"}
549574

550-
if where_statement := self.assemble_where_statement(filters):
575+
if where_statement := self.assemble_where_statement(filters, database_accessor=self.database_accessor):
551576
return main_query + "\n" + where_statement
552577

553578
return main_query
@@ -561,7 +586,7 @@ def assemble_data_exists_query(self) -> str:
561586
"{self.table}"
562587
"""
563588

564-
if where_statement := self.assemble_where_statement(self.filters):
589+
if where_statement := self.assemble_where_statement(self.filters, database_accessor=self.database_accessor):
565590
return f"{data_exists_query}\n{where_statement}"
566591

567592
return data_exists_query
@@ -861,7 +886,7 @@ def assemble_query(self) -> str:
861886
f"CAST({date_col} AS DATE) BETWEEN (DATE '{date_val}' - INTERVAL 14 DAY) AND DATE '{date_val}'"
862887
) # TODO: maybe parameterize interval days
863888

864-
if where_statement := self.assemble_where_statement(self.filters):
889+
if where_statement := self.assemble_where_statement(self.filters, database_accessor=self.database_accessor):
865890
return main_query + "\nAND\n" + where_statement.removeprefix("WHERE\n")
866891

867892
return main_query
@@ -1220,7 +1245,7 @@ def assemble_query(self) -> str:
12201245
order = {"max": "DESC", "min": "ASC"}[self.max_or_min]
12211246
return f"""
12221247
{self.query_boilerplate(self.transformation_statement())}
1223-
{self.assemble_where_statement(self.filters)}
1248+
{self.assemble_where_statement(self.filters, database_accessor=self.database_accessor)}
12241249
GROUP BY {self.in_memory_column}
12251250
ORDER BY {self.name} {order}
12261251
LIMIT 1 -- only the first entry is needed
@@ -1343,14 +1368,27 @@ def assemble_name(self) -> str:
13431368

13441369
def assemble_query(self) -> str:
13451370
"""Assemble the SQL query for calculating match rate between tables."""
1346-
right_column_statement = ",\n ".join(self.join_columns_right)
1347-
1348-
join_on_statement = "\n AND\n ".join(
1349-
[
1350-
f"lefty.{left_col} = righty.{right_col.split('.')[-1]}"
1351-
for left_col, right_col in zip(self.join_columns_left, self.join_columns_right, strict=False)
1352-
],
1353-
)
1371+
# Transform dotted column names based on data source:
1372+
# - With database_accessor: convert to underscores (value.shopId → value_shopId) for flattened data
1373+
# - Without database_accessor: keep dots for SELECT (value.shopId), use last part for JOIN (shopId)
1374+
if self.database_accessor:
1375+
right_column_statement = ",\n ".join([col.replace(".", "_") for col in self.join_columns_right])
1376+
join_on_statement = "\n AND\n ".join(
1377+
[
1378+
f"lefty.{left_col.replace('.', '_')} = righty.{right_col.replace('.', '_')}"
1379+
for left_col, right_col in zip(self.join_columns_left, self.join_columns_right, strict=False)
1380+
],
1381+
)
1382+
else:
1383+
# For native DuckDB struct columns: SELECT uses dotted notation,
1384+
# but DuckDB names the result column as just the last part
1385+
right_column_statement = ",\n ".join(self.join_columns_right)
1386+
join_on_statement = "\n AND\n ".join(
1387+
[
1388+
f"lefty.{left_col.split('.')[-1]} = righty.{right_col.split('.')[-1]}"
1389+
for left_col, right_col in zip(self.join_columns_left, self.join_columns_right, strict=False)
1390+
],
1391+
)
13541392

13551393
return f"""
13561394
WITH
@@ -1360,14 +1398,14 @@ def assemble_query(self) -> str:
13601398
TRUE AS in_right_table
13611399
FROM
13621400
"{self.right_table}"
1363-
{self.assemble_where_statement(self.filters_right)}
1401+
{self.assemble_where_statement(self.filters_right, database_accessor=self.database_accessor)}
13641402
),
13651403
lefty AS (
13661404
SELECT
13671405
*
13681406
FROM
13691407
"{self.left_table}"
1370-
{self.assemble_where_statement(self.filters_left)}
1408+
{self.assemble_where_statement(self.filters_left, database_accessor=self.database_accessor)}
13711409
)
13721410
13731411
SELECT
@@ -1397,15 +1435,15 @@ def assemble_data_exists_query(self) -> str:
13971435
COUNT(*) AS right_counter,
13981436
FROM
13991437
"{self.right_table}"
1400-
{self.assemble_where_statement(self.filters_right)}
1438+
{self.assemble_where_statement(self.filters_right, database_accessor=self.database_accessor)}
14011439
),
14021440
14031441
lefty AS (
14041442
SELECT
14051443
COUNT(*) AS left_counter,
14061444
FROM
14071445
"{self.left_table}"
1408-
{self.assemble_where_statement(self.filters_left)}
1446+
{self.assemble_where_statement(self.filters_left, database_accessor=self.database_accessor)}
14091447
)
14101448
14111449
SELECT
@@ -1508,7 +1546,10 @@ def assemble_name(self) -> str:
15081546

15091547
def assemble_query(self) -> str:
15101548
"""Assemble the SQL query for calculating relative count change."""
1511-
where_statement = self.assemble_where_statement(self.filters).replace("WHERE", "AND")
1549+
where_statement = self.assemble_where_statement(self.filters, database_accessor=self.database_accessor).replace(
1550+
"WHERE",
1551+
"AND",
1552+
)
15121553
date_col = self.date_filter["column"]
15131554
date_val = self.date_filter["value"]
15141555

@@ -1573,7 +1614,7 @@ def assemble_data_exists_query(self) -> str:
15731614
date_col = self.date_filter["column"]
15741615
date_val = self.date_filter["value"]
15751616

1576-
where_statement = self.assemble_where_statement(self.filters)
1617+
where_statement = self.assemble_where_statement(self.filters, database_accessor=self.database_accessor)
15771618
if where_statement:
15781619
return f"{data_exists_query}\n{where_statement} AND CAST({date_col} AS DATE) = DATE '{date_val}'"
15791620
return f"{data_exists_query}\nWHERE CAST({date_col} AS DATE) = DATE '{date_val}'"
@@ -1693,7 +1734,7 @@ def transformation_statement(self) -> str:
16931734
if filters:
16941735
filter_columns = ",\n".join([v["column"] for v in filters.values()])
16951736
filter_columns = ",\n" + filter_columns
1696-
where_statement = self.assemble_where_statement(filters)
1737+
where_statement = self.assemble_where_statement(filters, database_accessor=self.database_accessor)
16971738
where_statement = "\nAND\n" + where_statement.removeprefix("WHERE\n")
16981739
return f"""
16991740
WITH
@@ -1770,7 +1811,7 @@ def assemble_data_exists_query(self) -> str:
17701811

17711812
filters = {k: v for k, v in self.filters.items() if v["type"] != "date"}
17721813

1773-
where_statement = self.assemble_where_statement(filters)
1814+
where_statement = self.assemble_where_statement(filters, database_accessor=self.database_accessor)
17741815
if where_statement:
17751816
where_statement = f"{where_statement} AND CAST({date_col} AS DATE) = DATE '{date_val}'"
17761817
else:

src/koality/executor.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,9 @@ def fetch_data_into_memory(self, data_requirements: defaultdict[str, defaultdict
340340
select_parts.append("*")
341341
continue
342342
if isinstance(col, str) and "." in col:
343-
flat = col.split(".")[-1]
343+
# Replace dots with underscores for deterministic aliasing
344+
# e.g., "value.shopId" becomes "value_shopId"
345+
flat = col.replace(".", "_")
344346
# Make flattened name unique if duplicate arises
345347
base = flat
346348
idx = 1
@@ -364,13 +366,15 @@ def fetch_data_into_memory(self, data_requirements: defaultdict[str, defaultdict
364366
select_parts.append(col)
365367
columns = ", ".join(select_parts)
366368

367-
# Combine all unique filter groups. Treat date-range filters specially and
368-
# combine them with other filters using AND (date range applies to all other filters).
369+
# Combine all unique filter groups. Separate date filters from other filters.
370+
# All date-related conditions (BETWEEN ranges and date equality) should be ORed.
371+
# Non-date filters should be ANDed with the date conditions.
369372
date_filters_sql = set()
370373
other_filters_sql = set()
371374

372375
for filter_group in requirements["filters"]:
373376
filter_dict = {}
377+
date_filter_dict = {}
374378
for item in filter_group:
375379
# Expect each item to be a (name, frozenset(cfg_items)) tuple
376380
if not (isinstance(item, tuple) and len(item) == _DATE_RANGE_TUPLE_SIZE):
@@ -390,8 +394,24 @@ def fetch_data_into_memory(self, data_requirements: defaultdict[str, defaultdict
390394
date_filters_sql.add(f"({cond})")
391395
# date_range handled; continue to next group
392396
continue
393-
filter_dict[name] = dict(cfg)
397+
# Separate date-type filters from other filters
398+
if cfg.get("type") == "date":
399+
date_filter_dict[name] = dict(cfg)
400+
else:
401+
filter_dict[name] = dict(cfg)
402+
403+
# Process date filters separately and add to date_filters_sql
404+
if date_filter_dict:
405+
where_clause = DataQualityCheck.assemble_where_statement(
406+
date_filter_dict,
407+
strip_dotted_columns=False,
408+
)
409+
if where_clause.strip().startswith("WHERE"):
410+
conditions = where_clause.strip()[len("WHERE") :].strip()
411+
if conditions:
412+
date_filters_sql.add(f"({conditions})")
394413

414+
# Process non-date filters
395415
if filter_dict:
396416
# When fetching from the source DB, preserve dotted column expressions
397417
# (e.g., "value.shopId") in the WHERE so the source provider sees the
@@ -403,7 +423,7 @@ def fetch_data_into_memory(self, data_requirements: defaultdict[str, defaultdict
403423
if conditions:
404424
other_filters_sql.add(f"({conditions})")
405425

406-
# Build final WHERE clause: if we have date filters, AND them with other filters (if any).
426+
# Build final WHERE clause: OR all date filters together, AND with other filters.
407427
final_where_clause = ""
408428
if date_filters_sql and other_filters_sql:
409429
date_part = " OR ".join(sorted(date_filters_sql))

src/koality/models.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,32 @@ def persist_results(self) -> bool:
192192
class _Check(_LocalDefaults):
193193
"""Base model for all check configurations."""
194194

195+
@model_validator(mode="after")
196+
def validate_filters_have_columns(self) -> Self:
197+
"""Validate that all filters with concrete values have columns specified.
198+
199+
This validation runs after defaults merging, ensuring the final filter
200+
configuration is complete.
201+
"""
202+
for filter_name, filter_config in self.filters.items():
203+
# Skip identifier filters without concrete values (naming-only)
204+
if filter_config.type == "identifier" and (filter_config.value is None or filter_config.value == "*"):
205+
continue
206+
207+
# Skip partial filters with no value
208+
if filter_config.value is None:
209+
continue
210+
211+
# Filter has a value but no column - this is an error
212+
if filter_config.column is None:
213+
msg = (
214+
f"Filter '{filter_name}' has value '{filter_config.value}' "
215+
f"but no column specified. Add 'column: <column_name>' to the filter definition."
216+
)
217+
raise ValueError(msg)
218+
219+
return self
220+
195221

196222
class _SingleTableCheck(_Check):
197223
"""Base model for checks that operate on a single table."""

0 commit comments

Comments (0)