Skip to content

Commit a8410e7

Browse files
committed
Add exclude filter
1 parent 105623d commit a8410e7

File tree

3 files changed

+104
-22
lines changed

3 files changed

+104
-22
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# 0.7.5
22

33
- Fix a bug preventing select_related on multiple fields from working
4+
- Add django style `exclude` to filter
45

56
# 0.7.4
67

atomdb/sql.py

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,43 @@ def distinct(self, *args):
10731073
distinct_clauses=distinct_clauses, related_clauses=related_clauses
10741074
)
10751075

1076+
def where_clause(self, k: str, v: Any, related_clauses: ListType):
1077+
""" Create a where clause from a django-style parameter.
1078+
This will modify the list of related clauses if a join occurs.
1079+
1080+
Parameters
1081+
----------
1082+
k: str
1083+
The filter key, eg name__startswith
1084+
v: object
1085+
The value
1086+
related_clauses: list
1087+
List of related clauses
1088+
1089+
Returns
1090+
-------
1091+
clause: sqlalchemy.sq.expression
1092+
The filter clause
1093+
1094+
"""
1095+
model = self.proxy.model
1096+
op = "eq"
1097+
if "__" in k:
1098+
parts = k.split("__")
1099+
if parts[-1] in QUERY_OPS:
1100+
op = parts[-1]
1101+
k = "__".join(parts[:-1])
1102+
col = resolve_member_column(model, k, related_clauses)
1103+
1104+
# Support lookups by model
1105+
if isinstance(v, Model):
1106+
v = v.serializer.flatten_object(v, scope={})
1107+
elif op in ("in", "notin"):
1108+
# Flatten lists when using in or notin ops
1109+
v = model.serializer.flatten(v, scope={})
1110+
1111+
return getattr(col, QUERY_OPS[op])(v)
1112+
10761113
def filter(self, *args, **kwargs: DictType[str, Any]):
10771114
"""Filter the query by the given parameters. This accepts sqlalchemy
10781115
filters by arguments and django-style parameters as kwargs.
@@ -1094,30 +1131,53 @@ def filter(self, *args, **kwargs: DictType[str, Any]):
10941131
filter_clauses = self.filter_clauses + list(args)
10951132
related_clauses = self.related_clauses[:]
10961133

1097-
connection_kwarg, restore_kwarg = p.connection_kwarg, p.restore_kwarg
1134+
connection_kwarg = p.connection_kwarg
1135+
restore_kwarg = p.restore_kwarg
1136+
1137+
# Build the filter operations
1138+
for k, v in kwargs.items():
1139+
if k == connection_kwarg or k == restore_kwarg:
1140+
continue
1141+
filter_clauses.append(self.where_clause(k, v, related_clauses))
1142+
1143+
return self.clone(
1144+
connection=kwargs.get(connection_kwarg, self.connection),
1145+
force_restore=kwargs.get(restore_kwarg, self.force_restore),
1146+
filter_clauses=filter_clauses,
1147+
related_clauses=related_clauses,
1148+
)
1149+
1150+
def exclude(self, *args, **kwargs: DictType[str, Any]):
1151+
"""Exclude results matching the given parameters by wrapping each
1152+
clause in a NOT expression. This accepts sqlalchemy filters by
1153+
arguments and django-style parameters as kwargs.
1154+
1155+
Parameters
1156+
----------
1157+
args: List
1158+
List of sqlalchemy filters
1159+
kwargs: Dict[str, object]
1160+
Django style filters to use
1161+
1162+
Returns
1163+
-------
1164+
query: SQLQuerySet
1165+
A clone of this queryset with the excluded filter terms added.
1166+
1167+
"""
1168+
p = self.proxy
1169+
filter_clauses = self.filter_clauses + [sa.not_(it) for it in args]
1170+
related_clauses = self.related_clauses[:]
1171+
1172+
connection_kwarg = p.connection_kwarg
1173+
restore_kwarg = p.restore_kwarg
10981174

10991175
# Build the filter operations
11001176
for k, v in kwargs.items():
11011177
if k == connection_kwarg or k == restore_kwarg:
11021178
continue
1103-
model = p.model
1104-
op = "eq"
1105-
if "__" in k:
1106-
parts = k.split("__")
1107-
if parts[-1] in QUERY_OPS:
1108-
op = parts[-1]
1109-
k = "__".join(parts[:-1])
1110-
col = resolve_member_column(model, k, related_clauses)
1111-
1112-
# Support lookups by model
1113-
if isinstance(v, Model):
1114-
v = v.serializer.flatten_object(v, scope={})
1115-
elif op in ("in", "notin"):
1116-
# Flatten lists when using in or notin ops
1117-
v = model.serializer.flatten(v, scope={})
1118-
1119-
clause = getattr(col, QUERY_OPS[op])(v)
1120-
filter_clauses.append(clause)
1179+
clause = self.where_clause(k, v, related_clauses)
1180+
filter_clauses.append(sa.not_(clause))
11211181

11221182
return self.clone(
11231183
connection=kwargs.get(connection_kwarg, self.connection),

tests/test_sql.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,17 +1036,17 @@ async def test_filters(db):
10361036
await reset_tables(User)
10371037

10381038
user, created = await User.objects.get_or_create(
1039-
name=faker.name(), email=faker.email(), age=21, active=True
1039+
name="Bob", email=faker.email(), age=21, active=True
10401040
)
10411041
assert created
10421042

10431043
user2, created = await User.objects.get_or_create(
1044-
name=faker.name(), email=faker.email(), age=48, active=False, rating=10.0
1044+
name="Tom", email=faker.email(), age=48, active=False, rating=10.0
10451045
)
10461046
assert created
10471047

10481048
# Startswith
1049-
u = await User.objects.get(name__startswith=user.name[0])
1049+
u = await User.objects.get(name__startswith="B")
10501050
assert u.name == user.name
10511051
assert u is user # Now cached
10521052

@@ -1070,6 +1070,9 @@ async def test_filters(db):
10701070
users = await User.objects.filter(age__lt=30)
10711071
assert len(users) == 1 and users[0].age == user.age
10721072

1073+
users = await User.objects.exclude(age=21)
1074+
assert len(users) == 1 and users[0].age == 48
1075+
10731076
# Not supported
10741077
with pytest.raises(ValueError):
10751078
users = await User.objects.filter(age__xor=1)
@@ -1083,6 +1086,24 @@ async def test_filters(db):
10831086
users = await User.objects.filter(does_not_exist=True)
10841087

10851088

1089+
async def test_filter_exclude(db):
1090+
await reset_tables(User)
1091+
# Create second user
1092+
await User.objects.create(name="Bob", email="[email protected]", age=40, active=True)
1093+
await User.objects.create(name="Jack", email="[email protected]", age=30, active=False)
1094+
await User.objects.create(name="Bob", email="[email protected]", age=20, active=False)
1095+
1096+
users = await User.objects.filter(name__startswith="B").exclude(
1097+
email__endswith="other.com")
1098+
assert len(users) == 1 and users[0].email == "[email protected]"
1099+
1100+
users = await User.objects.exclude(active=True, age__lt=25)
1101+
assert len(users) == 1 and users[0].name == "Jack"
1102+
1103+
users = await User.objects.exclude(name="Bob")
1104+
assert len(users) == 1 and users[0].name == "Jack"
1105+
1106+
10861107
async def test_update(db):
10871108
await reset_tables(User)
10881109
# Create second user

0 commit comments

Comments
 (0)