Skip to content

Commit daca7b2

Browse files
authored
feat: add NullFilter and NotNullFilter for IS NULL filtering (#290)
Add NullFilter and NotNullFilter classes for filtering database queries where columns are NULL or NOT NULL, with full framework integration for FastAPI and Litestar.
1 parent 694761d commit daca7b2

File tree

5 files changed

+386
-2
lines changed

5 files changed

+386
-2
lines changed

sqlspec/core/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@
130130
InCollectionFilter,
131131
LimitOffsetFilter,
132132
NotInCollectionFilter,
133+
NotNullFilter,
134+
NullFilter,
133135
OrderByFilter,
134136
SearchFilter,
135137
StatementFilter,
@@ -225,6 +227,8 @@
225227
"LimitOffsetFilter",
226228
"MultiLevelCache",
227229
"NotInCollectionFilter",
230+
"NotNullFilter",
231+
"NullFilter",
228232
"OperationProfile",
229233
"OperationType",
230234
"OrderByFilter",

sqlspec/core/filters.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
"NotAnyCollectionFilter",
4646
"NotInCollectionFilter",
4747
"NotInSearchFilter",
48+
"NotNullFilter",
49+
"NullFilter",
4850
"OffsetPagination",
4951
"OnBeforeAfterFilter",
5052
"OrderByFilter",
@@ -728,6 +730,73 @@ def get_cache_key(self) -> tuple[Any, ...]:
728730
return ("SearchFilter", field_names, self.value, self.ignore_case)
729731

730732

733+
class NullFilter(StatementFilter):
734+
"""Filter for IS NULL queries.
735+
736+
Constructs WHERE field_name IS NULL clauses.
737+
"""
738+
739+
__slots__ = ("_field_name",)
740+
741+
def __init__(self, field_name: str) -> None:
742+
self._field_name = field_name
743+
744+
@property
745+
def field_name(self) -> str:
746+
return self._field_name
747+
748+
def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
749+
"""Extract filter parameters.
750+
751+
Returns empty parameters since IS NULL requires no values.
752+
"""
753+
return [], {}
754+
755+
def append_to_statement(self, statement: "SQL") -> "SQL":
756+
"""Apply IS NULL filter to SQL expression."""
757+
col_expr = exp.column(self.field_name)
758+
is_null_condition = exp.Is(this=col_expr, expression=exp.Null())
759+
return statement.where(is_null_condition)
760+
761+
def get_cache_key(self) -> tuple[Any, ...]:
762+
"""Return cache key for this filter configuration."""
763+
return ("NullFilter", self.field_name)
764+
765+
766+
class NotNullFilter(StatementFilter):
767+
"""Filter for IS NOT NULL queries.
768+
769+
Constructs WHERE field_name IS NOT NULL clauses.
770+
"""
771+
772+
__slots__ = ("_field_name",)
773+
774+
def __init__(self, field_name: str) -> None:
775+
self._field_name = field_name
776+
777+
@property
778+
def field_name(self) -> str:
779+
return self._field_name
780+
781+
def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
782+
"""Extract filter parameters.
783+
784+
Returns empty parameters since IS NOT NULL requires no values.
785+
"""
786+
return [], {}
787+
788+
def append_to_statement(self, statement: "SQL") -> "SQL":
789+
"""Apply IS NOT NULL filter to SQL expression."""
790+
col_expr = exp.column(self.field_name)
791+
is_null_condition = exp.Is(this=col_expr, expression=exp.Null())
792+
is_not_null_condition = exp.Not(this=is_null_condition)
793+
return statement.where(is_not_null_condition)
794+
795+
def get_cache_key(self) -> tuple[Any, ...]:
796+
"""Return cache key for this filter configuration."""
797+
return ("NotNullFilter", self.field_name)
798+
799+
731800
class NotInSearchFilter(SearchFilter):
732801
"""Filter for negated text search queries.
733802
@@ -836,6 +905,8 @@ def apply_filter(statement: "SQL", filter_obj: StatementFilter) -> "SQL":
836905
| NotInSearchFilter
837906
| AnyCollectionFilter[Any]
838907
| NotAnyCollectionFilter[Any]
908+
| NullFilter
909+
| NotNullFilter
839910
)
840911

841912

sqlspec/extensions/fastapi/providers.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
InCollectionFilter,
2020
LimitOffsetFilter,
2121
NotInCollectionFilter,
22+
NotNullFilter,
23+
NullFilter,
2224
OrderByFilter,
2325
SearchFilter,
2426
)
@@ -108,6 +110,10 @@ class FilterConfig(TypedDict):
108110
"""Fields that support not-in collection filtering. Can be single field or set of fields with type info."""
109111
in_fields: NotRequired[FieldNameType | set[FieldNameType]]
110112
"""Fields that support in-collection filtering. Can be single field or set of fields with type info."""
113+
null_fields: NotRequired[str | set[str]]
114+
"""Fields that support IS NULL filtering. Can be single field name or set of field names."""
115+
not_null_fields: NotRequired[str | set[str]]
116+
"""Fields that support IS NOT NULL filtering. Can be single field name or set of field names."""
111117

112118

113119
class DependencyCache(metaclass=SingletonMeta):
@@ -187,6 +193,8 @@ async def list_users(
187193
"sort_field",
188194
"not_in_fields",
189195
"in_fields",
196+
"null_fields",
197+
"not_null_fields",
190198
}
191199

192200
has_filters = False
@@ -512,6 +520,58 @@ def provide_in_filter(
512520
)
513521
annotations[param_name] = Annotated["InCollectionFilter[Any] | None", Depends(in_provider)]
514522

523+
if null_fields := config.get("null_fields"):
524+
null_fields = {null_fields} if isinstance(null_fields, str) else null_fields
525+
for field_name in null_fields:
526+
527+
def create_null_filter_provider(fname: str = field_name) -> "Callable[..., NullFilter | None]":
528+
def provide_null_filter(
529+
is_null: Annotated[
530+
bool | None,
531+
Query(alias=camelize(f"{fname}_is_null"), description=f"Filter where {fname} IS NULL"),
532+
] = None,
533+
) -> "NullFilter | None":
534+
return NullFilter(field_name=fname) if is_null else None
535+
536+
return provide_null_filter
537+
538+
null_provider = create_null_filter_provider()
539+
param_name = f"{field_name}_null_filter"
540+
params.append(
541+
inspect.Parameter(
542+
name=param_name,
543+
kind=inspect.Parameter.KEYWORD_ONLY,
544+
annotation=Annotated["NullFilter | None", Depends(null_provider)],
545+
)
546+
)
547+
annotations[param_name] = Annotated["NullFilter | None", Depends(null_provider)]
548+
549+
if not_null_fields := config.get("not_null_fields"):
550+
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else not_null_fields
551+
for field_name in not_null_fields:
552+
553+
def create_not_null_filter_provider(fname: str = field_name) -> "Callable[..., NotNullFilter | None]":
554+
def provide_not_null_filter(
555+
is_not_null: Annotated[
556+
bool | None,
557+
Query(alias=camelize(f"{fname}_is_not_null"), description=f"Filter where {fname} IS NOT NULL"),
558+
] = None,
559+
) -> "NotNullFilter | None":
560+
return NotNullFilter(field_name=fname) if is_not_null else None
561+
562+
return provide_not_null_filter
563+
564+
not_null_provider = create_not_null_filter_provider()
565+
param_name = f"{field_name}_not_null_filter"
566+
params.append(
567+
inspect.Parameter(
568+
name=param_name,
569+
kind=inspect.Parameter.KEYWORD_ONLY,
570+
annotation=Annotated["NotNullFilter | None", Depends(not_null_provider)],
571+
)
572+
)
573+
annotations[param_name] = Annotated["NotNullFilter | None", Depends(not_null_provider)]
574+
515575
_aggregate_filter_function.__signature__ = inspect.Signature( # type: ignore[attr-defined]
516576
parameters=params, return_annotation=Annotated["list[FilterTypes]", _aggregate_filter_function]
517577
)

sqlspec/extensions/litestar/providers.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
InCollectionFilter,
2121
LimitOffsetFilter,
2222
NotInCollectionFilter,
23+
NotNullFilter,
24+
NullFilter,
2325
OrderByFilter,
2426
SearchFilter,
2527
)
@@ -91,6 +93,10 @@ class FilterConfig(TypedDict):
9193
updated_at: NotRequired[bool]
9294
not_in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
9395
in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
96+
null_fields: NotRequired[str | set[str] | list[str]]
97+
"""Fields that support IS NULL filtering."""
98+
not_null_fields: NotRequired[str | set[str] | list[str]]
99+
"""Fields that support IS NOT NULL filtering."""
94100

95101

96102
class DependencyCache(metaclass=SingletonMeta):
@@ -152,7 +158,7 @@ def _make_hashable(value: Any) -> HashableType:
152158
return str(value)
153159

154160

155-
def _create_statement_filters(
161+
def _create_statement_filters( # noqa: C901
156162
config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
157163
) -> dict[str, Provide]:
158164
"""Create filter dependencies based on configuration.
@@ -294,6 +300,40 @@ def provide_in_filter( # pyright: ignore
294300
provider = create_in_filter_provider(field_def) # type: ignore
295301
filters[f"{field_def.name}_in_filter"] = Provide(provider, sync_to_thread=False) # pyright: ignore
296302

303+
if null_fields := config.get("null_fields"):
304+
null_fields = {null_fields} if isinstance(null_fields, str) else set(null_fields)
305+
306+
for field_name in null_fields:
307+
308+
def create_null_filter_provider(fname: str) -> Callable[..., NullFilter | None]:
309+
def provide_null_filter(
310+
is_null: bool | None = Parameter(query=camelize(f"{fname}_is_null"), default=None, required=False),
311+
) -> NullFilter | None:
312+
return NullFilter(field_name=fname) if is_null else None
313+
314+
return provide_null_filter
315+
316+
null_provider = create_null_filter_provider(field_name)
317+
filters[f"{field_name}_null_filter"] = Provide(null_provider, sync_to_thread=False)
318+
319+
if not_null_fields := config.get("not_null_fields"):
320+
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else set(not_null_fields)
321+
322+
for field_name in not_null_fields:
323+
324+
def create_not_null_filter_provider(fname: str) -> Callable[..., NotNullFilter | None]:
325+
def provide_not_null_filter(
326+
is_not_null: bool | None = Parameter(
327+
query=camelize(f"{fname}_is_not_null"), default=None, required=False
328+
),
329+
) -> NotNullFilter | None:
330+
return NotNullFilter(field_name=fname) if is_not_null else None
331+
332+
return provide_not_null_filter
333+
334+
not_null_provider = create_not_null_filter_provider(field_name)
335+
filters[f"{field_name}_not_null_filter"] = Provide(not_null_provider, sync_to_thread=False)
336+
297337
if filters:
298338
filters[dep_defaults.FILTERS_DEPENDENCY_KEY] = Provide(
299339
_create_filter_aggregate_function(config), sync_to_thread=False
@@ -302,7 +342,7 @@ def provide_in_filter( # pyright: ignore
302342
return filters
303343

304344

305-
def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., list[FilterTypes]]:
345+
def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., list[FilterTypes]]: # noqa: C901
306346
"""Create filter aggregation function based on configuration.
307347
308348
Args:
@@ -391,6 +431,28 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
391431
)
392432
annotations[f"{field_def.name}_in_filter"] = InCollectionFilter[field_def.type_hint] # type: ignore
393433

434+
if null_fields := config.get("null_fields"):
435+
null_fields = {null_fields} if isinstance(null_fields, str) else set(null_fields)
436+
for field_name in null_fields:
437+
parameters[f"{field_name}_null_filter"] = inspect.Parameter(
438+
name=f"{field_name}_null_filter",
439+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
440+
default=Dependency(skip_validation=True),
441+
annotation=NullFilter | None,
442+
)
443+
annotations[f"{field_name}_null_filter"] = NullFilter | None
444+
445+
if not_null_fields := config.get("not_null_fields"):
446+
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else set(not_null_fields)
447+
for field_name in not_null_fields:
448+
parameters[f"{field_name}_not_null_filter"] = inspect.Parameter(
449+
name=f"{field_name}_not_null_filter",
450+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
451+
default=Dependency(skip_validation=True),
452+
annotation=NotNullFilter | None,
453+
)
454+
annotations[f"{field_name}_not_null_filter"] = NotNullFilter | None
455+
394456
def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
395457
"""Aggregate filter dependencies based on configuration.
396458
@@ -438,6 +500,21 @@ def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
438500
filter_ = kwargs.get(f"{field_def.name}_in_filter")
439501
if filter_ is not None:
440502
filters.append(filter_)
503+
504+
if null_fields := config.get("null_fields"):
505+
null_fields = {null_fields} if isinstance(null_fields, str) else set(null_fields)
506+
for field_name in null_fields:
507+
filter_ = kwargs.get(f"{field_name}_null_filter")
508+
if filter_ is not None:
509+
filters.append(filter_)
510+
511+
if not_null_fields := config.get("not_null_fields"):
512+
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else set(not_null_fields)
513+
for field_name in not_null_fields:
514+
filter_ = kwargs.get(f"{field_name}_not_null_filter")
515+
if filter_ is not None:
516+
filters.append(filter_)
517+
441518
return filters
442519

443520
provide_filters.__signature__ = inspect.Signature( # type: ignore

0 commit comments

Comments
 (0)