Skip to content

Commit 7d81829

Browse files
Merge pull request #510 from bast0006/bast0006-new-infraction-filters
Add new infraction filters for the infraction rescheduler
2 parents 03c787b + b076394 commit 7d81829

File tree

2 files changed

+241
-6
lines changed

2 files changed

+241
-6
lines changed

pydis_site/apps/api/tests/test_infractions.py

Lines changed: 168 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
from datetime import datetime as dt, timedelta, timezone
23
from unittest.mock import patch
34
from urllib.parse import quote
@@ -16,7 +17,7 @@ def setUp(self):
1617
self.client.force_authenticate(user=None)
1718

1819
def test_detail_lookup_returns_401(self):
19-
url = reverse('bot:infraction-detail', args=(5,), host='api')
20+
url = reverse('bot:infraction-detail', args=(6,), host='api')
2021
response = self.client.get(url)
2122

2223
self.assertEqual(response.status_code, 401)
@@ -34,7 +35,7 @@ def test_create_returns_401(self):
3435
self.assertEqual(response.status_code, 401)
3536

3637
def test_partial_update_returns_401(self):
37-
url = reverse('bot:infraction-detail', args=(5,), host='api')
38+
url = reverse('bot:infraction-detail', args=(6,), host='api')
3839
response = self.client.patch(url, data={'reason': 'Have a nice day.'})
3940

4041
self.assertEqual(response.status_code, 401)
@@ -44,7 +45,7 @@ class InfractionTests(APISubdomainTestCase):
4445
@classmethod
4546
def setUpTestData(cls):
4647
cls.user = User.objects.create(
47-
id=5,
48+
id=6,
4849
name='james',
4950
discriminator=1,
5051
)
@@ -64,6 +65,30 @@ def setUpTestData(cls):
6465
reason='James is an ass, and we won\'t be working with him again.',
6566
active=False
6667
)
68+
cls.mute_permanent = Infraction.objects.create(
69+
user_id=cls.user.id,
70+
actor_id=cls.user.id,
71+
type='mute',
72+
reason='He has a filthy mouth and I am his soap.',
73+
active=True,
74+
expires_at=None
75+
)
76+
cls.superstar_expires_soon = Infraction.objects.create(
77+
user_id=cls.user.id,
78+
actor_id=cls.user.id,
79+
type='superstar',
80+
reason='This one doesn\'t matter anymore.',
81+
active=True,
82+
expires_at=datetime.datetime.utcnow() + datetime.timedelta(hours=5)
83+
)
84+
cls.voiceban_expires_later = Infraction.objects.create(
85+
user_id=cls.user.id,
86+
actor_id=cls.user.id,
87+
type='voice_ban',
88+
reason='Jet engine mic',
89+
active=True,
90+
expires_at=datetime.datetime.utcnow() + datetime.timedelta(days=5)
91+
)
6792

6893
def test_list_all(self):
6994
"""Tests the list-view, which should be ordered by inserted_at (newest first)."""
@@ -73,9 +98,12 @@ def test_list_all(self):
7398
self.assertEqual(response.status_code, 200)
7499
infractions = response.json()
75100

76-
self.assertEqual(len(infractions), 2)
77-
self.assertEqual(infractions[0]['id'], self.ban_inactive.id)
78-
self.assertEqual(infractions[1]['id'], self.ban_hidden.id)
101+
self.assertEqual(len(infractions), 5)
102+
self.assertEqual(infractions[0]['id'], self.voiceban_expires_later.id)
103+
self.assertEqual(infractions[1]['id'], self.superstar_expires_soon.id)
104+
self.assertEqual(infractions[2]['id'], self.mute_permanent.id)
105+
self.assertEqual(infractions[3]['id'], self.ban_inactive.id)
106+
self.assertEqual(infractions[4]['id'], self.ban_hidden.id)
79107

80108
def test_filter_search(self):
81109
url = reverse('bot:infraction-list', host='api')
@@ -98,6 +126,140 @@ def test_filter_field(self):
98126
self.assertEqual(len(infractions), 1)
99127
self.assertEqual(infractions[0]['id'], self.ban_hidden.id)
100128

129+
def test_filter_permanent_false(self):
130+
url = reverse('bot:infraction-list', host='api')
131+
response = self.client.get(f'{url}?type=mute&permanent=false')
132+
133+
self.assertEqual(response.status_code, 200)
134+
infractions = response.json()
135+
136+
self.assertEqual(len(infractions), 0)
137+
138+
def test_filter_permanent_true(self):
139+
url = reverse('bot:infraction-list', host='api')
140+
response = self.client.get(f'{url}?type=mute&permanent=true')
141+
142+
self.assertEqual(response.status_code, 200)
143+
infractions = response.json()
144+
145+
self.assertEqual(infractions[0]['id'], self.mute_permanent.id)
146+
147+
def test_filter_after(self):
148+
url = reverse('bot:infraction-list', host='api')
149+
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
150+
response = self.client.get(f'{url}?type=superstar&expires_after={target_time.isoformat()}')
151+
152+
self.assertEqual(response.status_code, 200)
153+
infractions = response.json()
154+
self.assertEqual(len(infractions), 0)
155+
156+
def test_filter_before(self):
157+
url = reverse('bot:infraction-list', host='api')
158+
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
159+
response = self.client.get(f'{url}?type=superstar&expires_before={target_time.isoformat()}')
160+
161+
self.assertEqual(response.status_code, 200)
162+
infractions = response.json()
163+
self.assertEqual(len(infractions), 1)
164+
self.assertEqual(infractions[0]['id'], self.superstar_expires_soon.id)
165+
166+
def test_filter_after_invalid(self):
167+
url = reverse('bot:infraction-list', host='api')
168+
response = self.client.get(f'{url}?expires_after=gibberish')
169+
170+
self.assertEqual(response.status_code, 400)
171+
self.assertEqual(list(response.json())[0], "expires_after")
172+
173+
def test_filter_before_invalid(self):
174+
url = reverse('bot:infraction-list', host='api')
175+
response = self.client.get(f'{url}?expires_before=000000000')
176+
177+
self.assertEqual(response.status_code, 400)
178+
self.assertEqual(list(response.json())[0], "expires_before")
179+
180+
def test_after_before_before(self):
181+
url = reverse('bot:infraction-list', host='api')
182+
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=4)
183+
target_time_late = datetime.datetime.utcnow() + datetime.timedelta(hours=6)
184+
response = self.client.get(
185+
f'{url}?expires_before={target_time_late.isoformat()}'
186+
f'&expires_after={target_time.isoformat()}'
187+
)
188+
189+
self.assertEqual(response.status_code, 200)
190+
self.assertEqual(len(response.json()), 1)
191+
self.assertEqual(response.json()[0]["id"], self.superstar_expires_soon.id)
192+
193+
def test_after_after_before_invalid(self):
194+
url = reverse('bot:infraction-list', host='api')
195+
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
196+
target_time_late = datetime.datetime.utcnow() + datetime.timedelta(hours=9)
197+
response = self.client.get(
198+
f'{url}?expires_before={target_time.isoformat()}'
199+
f'&expires_after={target_time_late.isoformat()}'
200+
)
201+
202+
self.assertEqual(response.status_code, 400)
203+
errors = list(response.json())
204+
self.assertIn("expires_before", errors)
205+
self.assertIn("expires_after", errors)
206+
207+
def test_permanent_after_invalid(self):
208+
url = reverse('bot:infraction-list', host='api')
209+
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
210+
response = self.client.get(f'{url}?permanent=true&expires_after={target_time.isoformat()}')
211+
212+
self.assertEqual(response.status_code, 400)
213+
errors = list(response.json())
214+
self.assertEqual("permanent", errors[0])
215+
216+
def test_permanent_before_invalid(self):
217+
url = reverse('bot:infraction-list', host='api')
218+
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
219+
response = self.client.get(f'{url}?permanent=true&expires_before={target_time.isoformat()}')
220+
221+
self.assertEqual(response.status_code, 400)
222+
errors = list(response.json())
223+
self.assertEqual("permanent", errors[0])
224+
225+
def test_nonpermanent_before(self):
226+
url = reverse('bot:infraction-list', host='api')
227+
target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=6)
228+
response = self.client.get(
229+
f'{url}?permanent=false&expires_before={target_time.isoformat()}'
230+
)
231+
232+
self.assertEqual(response.status_code, 200)
233+
self.assertEqual(len(response.json()), 1)
234+
self.assertEqual(response.json()[0]["id"], self.superstar_expires_soon.id)
235+
236+
def test_filter_manytypes(self):
237+
url = reverse('bot:infraction-list', host='api')
238+
response = self.client.get(f'{url}?types=mute,ban')
239+
240+
self.assertEqual(response.status_code, 200)
241+
infractions = response.json()
242+
self.assertEqual(len(infractions), 3)
243+
244+
def test_types_type_invalid(self):
245+
url = reverse('bot:infraction-list', host='api')
246+
response = self.client.get(f'{url}?types=mute,ban&type=superstar')
247+
248+
self.assertEqual(response.status_code, 400)
249+
errors = list(response.json())
250+
self.assertEqual("types", errors[0])
251+
252+
def test_sort_expiresby(self):
253+
url = reverse('bot:infraction-list', host='api')
254+
response = self.client.get(f'{url}?ordering=expires_at&permanent=false')
255+
self.assertEqual(response.status_code, 200)
256+
infractions = response.json()
257+
258+
self.assertEqual(len(infractions), 3)
259+
self.assertEqual(infractions[0]['id'], self.superstar_expires_soon.id)
260+
self.assertEqual(infractions[1]['id'], self.voiceban_expires_later.id)
261+
self.assertEqual(infractions[2]['id'], self.ban_hidden.id)
262+
101263
def test_returns_empty_for_no_match(self):
102264
url = reverse('bot:infraction-list', host='api')
103265
response = self.client.get(f'{url}?type=ban&search=poop')

pydis_site/apps/api/viewsets/bot/infraction.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from datetime import datetime
2+
3+
from django.db.models import QuerySet
14
from django.http.request import HttpRequest
25
from django_filters.rest_framework import DjangoFilterBackend
36
from rest_framework.decorators import action
@@ -43,10 +46,17 @@ class InfractionViewSet(
4346
- **offset** `int`: the initial index from which to return the results (default 0)
4447
- **search** `str`: regular expression applied to the infraction's reason
4548
- **type** `str`: the type of the infraction
49+
- **types** `str`: comma separated sequence of types to filter for
4650
- **user__id** `int`: snowflake of the user to which the infraction was applied
4751
- **ordering** `str`: comma-separated sequence of fields to order the returned results
52+
- **permanent** `bool`: whether or not to retrieve permanent infractions (default True)
53+
- **expires_after** `isodatetime`: the earliest expires_at time to return infractions for
54+
- **expires_before** `isodatetime`: the latest expires_at time to return infractions for
4855
4956
Invalid query parameters are ignored.
57+
Only one of `type` and `types` may be provided. If both `expires_before` and `expires_after`
58+
are provided, `expires_after` must come after `expires_before`.
59+
If `permanent` is provided and true, `expires_before` and `expires_after` must not be provided.
5060
5161
#### Response format
5262
Response is paginated but the result is returned without any pagination metadata.
@@ -156,6 +166,69 @@ def partial_update(self, request: HttpRequest, *_args, **_kwargs) -> Response:
156166

157167
return Response(serializer.data)
158168

169+
def get_queryset(self) -> QuerySet:
170+
"""
171+
Called to fetch the initial queryset, used to implement some of the more complex filters.
172+
173+
This provides the `permanent` and the `expires_gte` and `expires_lte` options.
174+
"""
175+
filter_permanent = self.request.query_params.get('permanent')
176+
additional_filters = {}
177+
if filter_permanent is not None:
178+
additional_filters['expires_at__isnull'] = filter_permanent.lower() == 'true'
179+
180+
filter_expires_after = self.request.query_params.get('expires_after')
181+
if filter_expires_after:
182+
try:
183+
additional_filters['expires_at__gte'] = datetime.fromisoformat(
184+
filter_expires_after
185+
)
186+
except ValueError:
187+
raise ValidationError({'expires_after': ['failed to convert to datetime']})
188+
189+
filter_expires_before = self.request.query_params.get('expires_before')
190+
if filter_expires_before:
191+
try:
192+
additional_filters['expires_at__lte'] = datetime.fromisoformat(
193+
filter_expires_before
194+
)
195+
except ValueError:
196+
raise ValidationError({'expires_before': ['failed to convert to datetime']})
197+
198+
if 'expires_at__lte' in additional_filters and 'expires_at__gte' in additional_filters:
199+
if additional_filters['expires_at__gte'] > additional_filters['expires_at__lte']:
200+
raise ValidationError({
201+
'expires_before': ['cannot be after expires_after'],
202+
'expires_after': ['cannot be before expires_before'],
203+
})
204+
205+
if (
206+
('expires_at__lte' in additional_filters or 'expires_at__gte' in additional_filters)
207+
and 'expires_at__isnull' in additional_filters
208+
and additional_filters['expires_at__isnull']
209+
):
210+
raise ValidationError({
211+
'permanent': [
212+
'cannot filter for permanent infractions at the'
213+
' same time as expires_at or expires_before',
214+
]
215+
})
216+
217+
if filter_expires_before:
218+
# Filter out permanent infractions specifically if we want ones that will expire
219+
# before a given date
220+
additional_filters['expires_at__isnull'] = False
221+
222+
filter_types = self.request.query_params.get('types')
223+
if filter_types:
224+
if self.request.query_params.get('type'):
225+
raise ValidationError({
226+
'types': ['you must provide only one of "type" or "types"'],
227+
})
228+
additional_filters['type__in'] = [i.strip() for i in filter_types.split(",")]
229+
230+
return self.queryset.filter(**additional_filters)
231+
159232
@action(url_path='expanded', detail=False)
160233
def list_expanded(self, *args, **kwargs) -> Response:
161234
"""

0 commit comments

Comments
 (0)