Skip to content

Commit c25d998

Browse files
committed
feat: add NullFilter and NotNullFilter for IS NULL filtering
Add NullFilter and NotNullFilter classes for filtering database queries where columns are NULL or NOT NULL, with framework integration for FastAPI and Litestar. - NullFilter generates WHERE column IS NULL clauses - NotNullFilter generates WHERE column IS NOT NULL clauses - FastAPI/Litestar providers support via null_fields/not_null_fields config - Builder API adds where_null() and where_not_null() aliases - 13 comprehensive unit tests added
1 parent 694761d commit c25d998

File tree

6 files changed

+412
-2
lines changed

6 files changed

+412
-2
lines changed

sqlspec/builder/_select.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,32 @@ def where_is_not_null(self, column: str | exp.Column) -> Self:
879879
condition: exp.Expression = col_expr.is_(exp.null()).not_()
880880
return self.where(condition)
881881

882+
def where_null(self, column: str | exp.Column) -> Self:
883+
"""Add WHERE column IS NULL clause.
884+
885+
Alias for where_is_null() for consistency with other SQL builders.
886+
887+
Args:
888+
column: Column name or expression to check for NULL.
889+
890+
Returns:
891+
Self for method chaining.
892+
"""
893+
return self.where_is_null(column)
894+
895+
def where_not_null(self, column: str | exp.Column) -> Self:
896+
"""Add WHERE column IS NOT NULL clause.
897+
898+
Alias for where_is_not_null() for consistency with other SQL builders.
899+
900+
Args:
901+
column: Column name or expression to check for NOT NULL.
902+
903+
Returns:
904+
Self for method chaining.
905+
"""
906+
return self.where_is_not_null(column)
907+
882908
def where_in(self, column: str | exp.Column, values: Any) -> Self:
883909
builder = cast("SQLBuilderProtocol", self)
884910
col_expr = parse_column_expression(column) if not isinstance(column, exp.Column) else column

sqlspec/core/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@
130130
InCollectionFilter,
131131
LimitOffsetFilter,
132132
NotInCollectionFilter,
133+
NotNullFilter,
134+
NullFilter,
133135
OrderByFilter,
134136
SearchFilter,
135137
StatementFilter,
@@ -225,6 +227,8 @@
225227
"LimitOffsetFilter",
226228
"MultiLevelCache",
227229
"NotInCollectionFilter",
230+
"NotNullFilter",
231+
"NullFilter",
228232
"OperationProfile",
229233
"OperationType",
230234
"OrderByFilter",

sqlspec/core/filters.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
"NotAnyCollectionFilter",
4646
"NotInCollectionFilter",
4747
"NotInSearchFilter",
48+
"NotNullFilter",
49+
"NullFilter",
4850
"OffsetPagination",
4951
"OnBeforeAfterFilter",
5052
"OrderByFilter",
@@ -728,6 +730,73 @@ def get_cache_key(self) -> tuple[Any, ...]:
728730
return ("SearchFilter", field_names, self.value, self.ignore_case)
729731

730732

733+
class NullFilter(StatementFilter):
734+
"""Filter for IS NULL queries.
735+
736+
Constructs WHERE field_name IS NULL clauses.
737+
"""
738+
739+
__slots__ = ("_field_name",)
740+
741+
def __init__(self, field_name: str) -> None:
742+
self._field_name = field_name
743+
744+
@property
745+
def field_name(self) -> str:
746+
return self._field_name
747+
748+
def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
749+
"""Extract filter parameters.
750+
751+
Returns empty parameters since IS NULL requires no values.
752+
"""
753+
return [], {}
754+
755+
def append_to_statement(self, statement: "SQL") -> "SQL":
756+
"""Apply IS NULL filter to SQL expression."""
757+
col_expr = exp.column(self.field_name)
758+
is_null_condition = exp.Is(this=col_expr, expression=exp.Null())
759+
return statement.where(is_null_condition)
760+
761+
def get_cache_key(self) -> tuple[Any, ...]:
762+
"""Return cache key for this filter configuration."""
763+
return ("NullFilter", self.field_name)
764+
765+
766+
class NotNullFilter(StatementFilter):
767+
"""Filter for IS NOT NULL queries.
768+
769+
Constructs WHERE field_name IS NOT NULL clauses.
770+
"""
771+
772+
__slots__ = ("_field_name",)
773+
774+
def __init__(self, field_name: str) -> None:
775+
self._field_name = field_name
776+
777+
@property
778+
def field_name(self) -> str:
779+
return self._field_name
780+
781+
def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
782+
"""Extract filter parameters.
783+
784+
Returns empty parameters since IS NOT NULL requires no values.
785+
"""
786+
return [], {}
787+
788+
def append_to_statement(self, statement: "SQL") -> "SQL":
789+
"""Apply IS NOT NULL filter to SQL expression."""
790+
col_expr = exp.column(self.field_name)
791+
is_null_condition = exp.Is(this=col_expr, expression=exp.Null())
792+
is_not_null_condition = exp.Not(this=is_null_condition)
793+
return statement.where(is_not_null_condition)
794+
795+
def get_cache_key(self) -> tuple[Any, ...]:
796+
"""Return cache key for this filter configuration."""
797+
return ("NotNullFilter", self.field_name)
798+
799+
731800
class NotInSearchFilter(SearchFilter):
732801
"""Filter for negated text search queries.
733802
@@ -836,6 +905,8 @@ def apply_filter(statement: "SQL", filter_obj: StatementFilter) -> "SQL":
836905
| NotInSearchFilter
837906
| AnyCollectionFilter[Any]
838907
| NotAnyCollectionFilter[Any]
908+
| NullFilter
909+
| NotNullFilter
839910
)
840911

841912

sqlspec/extensions/fastapi/providers.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
InCollectionFilter,
2020
LimitOffsetFilter,
2121
NotInCollectionFilter,
22+
NotNullFilter,
23+
NullFilter,
2224
OrderByFilter,
2325
SearchFilter,
2426
)
@@ -108,6 +110,10 @@ class FilterConfig(TypedDict):
108110
"""Fields that support not-in collection filtering. Can be single field or set of fields with type info."""
109111
in_fields: NotRequired[FieldNameType | set[FieldNameType]]
110112
"""Fields that support in-collection filtering. Can be single field or set of fields with type info."""
113+
null_fields: NotRequired[str | set[str]]
114+
"""Fields that support IS NULL filtering. Can be single field name or set of field names."""
115+
not_null_fields: NotRequired[str | set[str]]
116+
"""Fields that support IS NOT NULL filtering. Can be single field name or set of field names."""
111117

112118

113119
class DependencyCache(metaclass=SingletonMeta):
@@ -187,6 +193,8 @@ async def list_users(
187193
"sort_field",
188194
"not_in_fields",
189195
"in_fields",
196+
"null_fields",
197+
"not_null_fields",
190198
}
191199

192200
has_filters = False
@@ -512,6 +520,58 @@ def provide_in_filter(
512520
)
513521
annotations[param_name] = Annotated["InCollectionFilter[Any] | None", Depends(in_provider)]
514522

523+
if null_fields := config.get("null_fields"):
524+
null_fields = {null_fields} if isinstance(null_fields, str) else null_fields
525+
for field_name in null_fields:
526+
527+
def create_null_filter_provider(fname: str = field_name) -> "Callable[..., NullFilter | None]":
528+
def provide_null_filter(
529+
is_null: Annotated[
530+
bool | None,
531+
Query(alias=camelize(f"{fname}_is_null"), description=f"Filter where {fname} IS NULL"),
532+
] = None,
533+
) -> "NullFilter | None":
534+
return NullFilter(field_name=fname) if is_null else None
535+
536+
return provide_null_filter
537+
538+
null_provider = create_null_filter_provider()
539+
param_name = f"{field_name}_null_filter"
540+
params.append(
541+
inspect.Parameter(
542+
name=param_name,
543+
kind=inspect.Parameter.KEYWORD_ONLY,
544+
annotation=Annotated["NullFilter | None", Depends(null_provider)],
545+
)
546+
)
547+
annotations[param_name] = Annotated["NullFilter | None", Depends(null_provider)]
548+
549+
if not_null_fields := config.get("not_null_fields"):
550+
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else not_null_fields
551+
for field_name in not_null_fields:
552+
553+
def create_not_null_filter_provider(fname: str = field_name) -> "Callable[..., NotNullFilter | None]":
554+
def provide_not_null_filter(
555+
is_not_null: Annotated[
556+
bool | None,
557+
Query(alias=camelize(f"{fname}_is_not_null"), description=f"Filter where {fname} IS NOT NULL"),
558+
] = None,
559+
) -> "NotNullFilter | None":
560+
return NotNullFilter(field_name=fname) if is_not_null else None
561+
562+
return provide_not_null_filter
563+
564+
not_null_provider = create_not_null_filter_provider()
565+
param_name = f"{field_name}_not_null_filter"
566+
params.append(
567+
inspect.Parameter(
568+
name=param_name,
569+
kind=inspect.Parameter.KEYWORD_ONLY,
570+
annotation=Annotated["NotNullFilter | None", Depends(not_null_provider)],
571+
)
572+
)
573+
annotations[param_name] = Annotated["NotNullFilter | None", Depends(not_null_provider)]
574+
515575
_aggregate_filter_function.__signature__ = inspect.Signature( # type: ignore[attr-defined]
516576
parameters=params, return_annotation=Annotated["list[FilterTypes]", _aggregate_filter_function]
517577
)

sqlspec/extensions/litestar/providers.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
InCollectionFilter,
2121
LimitOffsetFilter,
2222
NotInCollectionFilter,
23+
NotNullFilter,
24+
NullFilter,
2325
OrderByFilter,
2426
SearchFilter,
2527
)
@@ -91,6 +93,10 @@ class FilterConfig(TypedDict):
9193
updated_at: NotRequired[bool]
9294
not_in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
9395
in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
96+
null_fields: NotRequired[str | set[str] | list[str]]
97+
"""Fields that support IS NULL filtering."""
98+
not_null_fields: NotRequired[str | set[str] | list[str]]
99+
"""Fields that support IS NOT NULL filtering."""
94100

95101

96102
class DependencyCache(metaclass=SingletonMeta):
@@ -152,7 +158,7 @@ def _make_hashable(value: Any) -> HashableType:
152158
return str(value)
153159

154160

155-
def _create_statement_filters(
161+
def _create_statement_filters( # noqa: C901
156162
config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
157163
) -> dict[str, Provide]:
158164
"""Create filter dependencies based on configuration.
@@ -294,6 +300,40 @@ def provide_in_filter( # pyright: ignore
294300
provider = create_in_filter_provider(field_def) # type: ignore
295301
filters[f"{field_def.name}_in_filter"] = Provide(provider, sync_to_thread=False) # pyright: ignore
296302

303+
if null_fields := config.get("null_fields"):
304+
null_fields = {null_fields} if isinstance(null_fields, str) else set(null_fields)
305+
306+
for field_name in null_fields:
307+
308+
def create_null_filter_provider(fname: str) -> Callable[..., NullFilter | None]:
309+
def provide_null_filter(
310+
is_null: bool | None = Parameter(query=camelize(f"{fname}_is_null"), default=None, required=False),
311+
) -> NullFilter | None:
312+
return NullFilter(field_name=fname) if is_null else None
313+
314+
return provide_null_filter
315+
316+
null_provider = create_null_filter_provider(field_name)
317+
filters[f"{field_name}_null_filter"] = Provide(null_provider, sync_to_thread=False)
318+
319+
if not_null_fields := config.get("not_null_fields"):
320+
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else set(not_null_fields)
321+
322+
for field_name in not_null_fields:
323+
324+
def create_not_null_filter_provider(fname: str) -> Callable[..., NotNullFilter | None]:
325+
def provide_not_null_filter(
326+
is_not_null: bool | None = Parameter(
327+
query=camelize(f"{fname}_is_not_null"), default=None, required=False
328+
),
329+
) -> NotNullFilter | None:
330+
return NotNullFilter(field_name=fname) if is_not_null else None
331+
332+
return provide_not_null_filter
333+
334+
not_null_provider = create_not_null_filter_provider(field_name)
335+
filters[f"{field_name}_not_null_filter"] = Provide(not_null_provider, sync_to_thread=False)
336+
297337
if filters:
298338
filters[dep_defaults.FILTERS_DEPENDENCY_KEY] = Provide(
299339
_create_filter_aggregate_function(config), sync_to_thread=False
@@ -302,7 +342,7 @@ def provide_in_filter( # pyright: ignore
302342
return filters
303343

304344

305-
def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., list[FilterTypes]]:
345+
def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., list[FilterTypes]]: # noqa: C901
306346
"""Create filter aggregation function based on configuration.
307347
308348
Args:
@@ -391,6 +431,28 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
391431
)
392432
annotations[f"{field_def.name}_in_filter"] = InCollectionFilter[field_def.type_hint] # type: ignore
393433

434+
if null_fields := config.get("null_fields"):
435+
null_fields = {null_fields} if isinstance(null_fields, str) else set(null_fields)
436+
for field_name in null_fields:
437+
parameters[f"{field_name}_null_filter"] = inspect.Parameter(
438+
name=f"{field_name}_null_filter",
439+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
440+
default=Dependency(skip_validation=True),
441+
annotation=NullFilter | None,
442+
)
443+
annotations[f"{field_name}_null_filter"] = NullFilter | None
444+
445+
if not_null_fields := config.get("not_null_fields"):
446+
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else set(not_null_fields)
447+
for field_name in not_null_fields:
448+
parameters[f"{field_name}_not_null_filter"] = inspect.Parameter(
449+
name=f"{field_name}_not_null_filter",
450+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
451+
default=Dependency(skip_validation=True),
452+
annotation=NotNullFilter | None,
453+
)
454+
annotations[f"{field_name}_not_null_filter"] = NotNullFilter | None
455+
394456
def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
395457
"""Aggregate filter dependencies based on configuration.
396458
@@ -438,6 +500,21 @@ def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
438500
filter_ = kwargs.get(f"{field_def.name}_in_filter")
439501
if filter_ is not None:
440502
filters.append(filter_)
503+
504+
if null_fields := config.get("null_fields"):
505+
null_fields = {null_fields} if isinstance(null_fields, str) else set(null_fields)
506+
for field_name in null_fields:
507+
filter_ = kwargs.get(f"{field_name}_null_filter")
508+
if filter_ is not None:
509+
filters.append(filter_)
510+
511+
if not_null_fields := config.get("not_null_fields"):
512+
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else set(not_null_fields)
513+
for field_name in not_null_fields:
514+
filter_ = kwargs.get(f"{field_name}_not_null_filter")
515+
if filter_ is not None:
516+
filters.append(filter_)
517+
441518
return filters
442519

443520
provide_filters.__signature__ = inspect.Signature( # type: ignore

0 commit comments

Comments
 (0)