Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sqlspec/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@
InCollectionFilter,
LimitOffsetFilter,
NotInCollectionFilter,
NotNullFilter,
NullFilter,
OrderByFilter,
SearchFilter,
StatementFilter,
Expand Down Expand Up @@ -225,6 +227,8 @@
"LimitOffsetFilter",
"MultiLevelCache",
"NotInCollectionFilter",
"NotNullFilter",
"NullFilter",
"OperationProfile",
"OperationType",
"OrderByFilter",
Expand Down
71 changes: 71 additions & 0 deletions sqlspec/core/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
"NotAnyCollectionFilter",
"NotInCollectionFilter",
"NotInSearchFilter",
"NotNullFilter",
"NullFilter",
"OffsetPagination",
"OnBeforeAfterFilter",
"OrderByFilter",
Expand Down Expand Up @@ -728,6 +730,73 @@ def get_cache_key(self) -> tuple[Any, ...]:
return ("SearchFilter", field_names, self.value, self.ignore_case)


class NullFilter(StatementFilter):
"""Filter for IS NULL queries.

Constructs WHERE field_name IS NULL clauses.
"""

__slots__ = ("_field_name",)

def __init__(self, field_name: str) -> None:
self._field_name = field_name

@property
def field_name(self) -> str:
return self._field_name

def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
"""Extract filter parameters.

Returns empty parameters since IS NULL requires no values.
"""
return [], {}

def append_to_statement(self, statement: "SQL") -> "SQL":
"""Apply IS NULL filter to SQL expression."""
col_expr = exp.column(self.field_name)
is_null_condition = exp.Is(this=col_expr, expression=exp.Null())
return statement.where(is_null_condition)

def get_cache_key(self) -> tuple[Any, ...]:
"""Return cache key for this filter configuration."""
return ("NullFilter", self.field_name)


class NotNullFilter(StatementFilter):
"""Filter for IS NOT NULL queries.

Constructs WHERE field_name IS NOT NULL clauses.
"""

__slots__ = ("_field_name",)

def __init__(self, field_name: str) -> None:
self._field_name = field_name

@property
def field_name(self) -> str:
return self._field_name

def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
"""Extract filter parameters.

Returns empty parameters since IS NOT NULL requires no values.
"""
return [], {}

def append_to_statement(self, statement: "SQL") -> "SQL":
"""Apply IS NOT NULL filter to SQL expression."""
col_expr = exp.column(self.field_name)
is_null_condition = exp.Is(this=col_expr, expression=exp.Null())
is_not_null_condition = exp.Not(this=is_null_condition)
return statement.where(is_not_null_condition)

def get_cache_key(self) -> tuple[Any, ...]:
"""Return cache key for this filter configuration."""
return ("NotNullFilter", self.field_name)


class NotInSearchFilter(SearchFilter):
"""Filter for negated text search queries.

Expand Down Expand Up @@ -836,6 +905,8 @@ def apply_filter(statement: "SQL", filter_obj: StatementFilter) -> "SQL":
| NotInSearchFilter
| AnyCollectionFilter[Any]
| NotAnyCollectionFilter[Any]
| NullFilter
| NotNullFilter
)


Expand Down
60 changes: 60 additions & 0 deletions sqlspec/extensions/fastapi/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
InCollectionFilter,
LimitOffsetFilter,
NotInCollectionFilter,
NotNullFilter,
NullFilter,
OrderByFilter,
SearchFilter,
)
Expand Down Expand Up @@ -108,6 +110,10 @@ class FilterConfig(TypedDict):
"""Fields that support not-in collection filtering. Can be single field or set of fields with type info."""
in_fields: NotRequired[FieldNameType | set[FieldNameType]]
"""Fields that support in-collection filtering. Can be single field or set of fields with type info."""
null_fields: NotRequired[str | set[str]]
"""Fields that support IS NULL filtering. Can be single field name or set of field names."""
not_null_fields: NotRequired[str | set[str]]
"""Fields that support IS NOT NULL filtering. Can be single field name or set of field names."""


class DependencyCache(metaclass=SingletonMeta):
Expand Down Expand Up @@ -187,6 +193,8 @@ async def list_users(
"sort_field",
"not_in_fields",
"in_fields",
"null_fields",
"not_null_fields",
}

has_filters = False
Expand Down Expand Up @@ -512,6 +520,58 @@ def provide_in_filter(
)
annotations[param_name] = Annotated["InCollectionFilter[Any] | None", Depends(in_provider)]

if null_fields := config.get("null_fields"):
null_fields = {null_fields} if isinstance(null_fields, str) else null_fields
for field_name in null_fields:

def create_null_filter_provider(fname: str = field_name) -> "Callable[..., NullFilter | None]":
def provide_null_filter(
is_null: Annotated[
bool | None,
Query(alias=camelize(f"{fname}_is_null"), description=f"Filter where {fname} IS NULL"),
] = None,
) -> "NullFilter | None":
return NullFilter(field_name=fname) if is_null else None

return provide_null_filter

null_provider = create_null_filter_provider()
param_name = f"{field_name}_null_filter"
params.append(
inspect.Parameter(
name=param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated["NullFilter | None", Depends(null_provider)],
)
)
annotations[param_name] = Annotated["NullFilter | None", Depends(null_provider)]

if not_null_fields := config.get("not_null_fields"):
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else not_null_fields
for field_name in not_null_fields:

def create_not_null_filter_provider(fname: str = field_name) -> "Callable[..., NotNullFilter | None]":
def provide_not_null_filter(
is_not_null: Annotated[
bool | None,
Query(alias=camelize(f"{fname}_is_not_null"), description=f"Filter where {fname} IS NOT NULL"),
] = None,
) -> "NotNullFilter | None":
return NotNullFilter(field_name=fname) if is_not_null else None

return provide_not_null_filter

not_null_provider = create_not_null_filter_provider()
param_name = f"{field_name}_not_null_filter"
params.append(
inspect.Parameter(
name=param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
annotation=Annotated["NotNullFilter | None", Depends(not_null_provider)],
)
)
annotations[param_name] = Annotated["NotNullFilter | None", Depends(not_null_provider)]

_aggregate_filter_function.__signature__ = inspect.Signature( # type: ignore[attr-defined]
parameters=params, return_annotation=Annotated["list[FilterTypes]", _aggregate_filter_function]
)
Expand Down
81 changes: 79 additions & 2 deletions sqlspec/extensions/litestar/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
InCollectionFilter,
LimitOffsetFilter,
NotInCollectionFilter,
NotNullFilter,
NullFilter,
OrderByFilter,
SearchFilter,
)
Expand Down Expand Up @@ -91,6 +93,10 @@ class FilterConfig(TypedDict):
updated_at: NotRequired[bool]
not_in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
in_fields: NotRequired[FieldNameType | set[FieldNameType] | list[str | FieldNameType]]
null_fields: NotRequired[str | set[str] | list[str]]
"""Fields that support IS NULL filtering."""
not_null_fields: NotRequired[str | set[str] | list[str]]
"""Fields that support IS NOT NULL filtering."""


class DependencyCache(metaclass=SingletonMeta):
Expand Down Expand Up @@ -152,7 +158,7 @@ def _make_hashable(value: Any) -> HashableType:
return str(value)


def _create_statement_filters(
def _create_statement_filters( # noqa: C901
config: FilterConfig, dep_defaults: DependencyDefaults = DEPENDENCY_DEFAULTS
) -> dict[str, Provide]:
"""Create filter dependencies based on configuration.
Expand Down Expand Up @@ -294,6 +300,40 @@ def provide_in_filter( # pyright: ignore
provider = create_in_filter_provider(field_def) # type: ignore
filters[f"{field_def.name}_in_filter"] = Provide(provider, sync_to_thread=False) # pyright: ignore

if null_fields := config.get("null_fields"):
null_fields = {null_fields} if isinstance(null_fields, str) else set(null_fields)

for field_name in null_fields:

def create_null_filter_provider(fname: str) -> Callable[..., NullFilter | None]:
def provide_null_filter(
is_null: bool | None = Parameter(query=camelize(f"{fname}_is_null"), default=None, required=False),
) -> NullFilter | None:
return NullFilter(field_name=fname) if is_null else None

return provide_null_filter

null_provider = create_null_filter_provider(field_name)
filters[f"{field_name}_null_filter"] = Provide(null_provider, sync_to_thread=False)

if not_null_fields := config.get("not_null_fields"):
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else set(not_null_fields)

for field_name in not_null_fields:

def create_not_null_filter_provider(fname: str) -> Callable[..., NotNullFilter | None]:
def provide_not_null_filter(
is_not_null: bool | None = Parameter(
query=camelize(f"{fname}_is_not_null"), default=None, required=False
),
) -> NotNullFilter | None:
return NotNullFilter(field_name=fname) if is_not_null else None

return provide_not_null_filter

not_null_provider = create_not_null_filter_provider(field_name)
filters[f"{field_name}_not_null_filter"] = Provide(not_null_provider, sync_to_thread=False)

if filters:
filters[dep_defaults.FILTERS_DEPENDENCY_KEY] = Provide(
_create_filter_aggregate_function(config), sync_to_thread=False
Expand All @@ -302,7 +342,7 @@ def provide_in_filter( # pyright: ignore
return filters


def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., list[FilterTypes]]:
def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., list[FilterTypes]]: # noqa: C901
"""Create filter aggregation function based on configuration.

Args:
Expand Down Expand Up @@ -391,6 +431,28 @@ def _create_filter_aggregate_function(config: FilterConfig) -> Callable[..., lis
)
annotations[f"{field_def.name}_in_filter"] = InCollectionFilter[field_def.type_hint] # type: ignore

if null_fields := config.get("null_fields"):
null_fields = {null_fields} if isinstance(null_fields, str) else set(null_fields)
for field_name in null_fields:
parameters[f"{field_name}_null_filter"] = inspect.Parameter(
name=f"{field_name}_null_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=NullFilter | None,
)
annotations[f"{field_name}_null_filter"] = NullFilter | None

if not_null_fields := config.get("not_null_fields"):
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else set(not_null_fields)
for field_name in not_null_fields:
parameters[f"{field_name}_not_null_filter"] = inspect.Parameter(
name=f"{field_name}_not_null_filter",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=Dependency(skip_validation=True),
annotation=NotNullFilter | None,
)
annotations[f"{field_name}_not_null_filter"] = NotNullFilter | None

def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
"""Aggregate filter dependencies based on configuration.

Expand Down Expand Up @@ -438,6 +500,21 @@ def provide_filters(**kwargs: FilterTypes) -> list[FilterTypes]:
filter_ = kwargs.get(f"{field_def.name}_in_filter")
if filter_ is not None:
filters.append(filter_)

if null_fields := config.get("null_fields"):
null_fields = {null_fields} if isinstance(null_fields, str) else set(null_fields)
for field_name in null_fields:
filter_ = kwargs.get(f"{field_name}_null_filter")
if filter_ is not None:
filters.append(filter_)

if not_null_fields := config.get("not_null_fields"):
not_null_fields = {not_null_fields} if isinstance(not_null_fields, str) else set(not_null_fields)
for field_name in not_null_fields:
filter_ = kwargs.get(f"{field_name}_not_null_filter")
if filter_ is not None:
filters.append(filter_)

return filters

provide_filters.__signature__ = inspect.Signature( # type: ignore
Expand Down
Loading
Loading