diff --git a/tests/test_models.py b/tests/test_models.py index 9ce1e5eb7..15f89856b 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -298,9 +298,11 @@ def setUp(self): super().setUp() # Insert many tokens, both expired and not, and grants. self.num_tokens = 100 - now = timezone.now() - earlier = now - timedelta(seconds=100) - later = now + timedelta(seconds=100) + self.delta_secs = 1000 + self.now = timezone.now() + self.earlier = self.now - timedelta(seconds=self.delta_secs) + self.later = self.now + timedelta(seconds=self.delta_secs) + app = Application.objects.create( name="test_app", redirect_uris="http://localhost http://example.com http://example.org", @@ -309,58 +311,54 @@ def setUp(self): authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, ) # make 200 access tokens, half current and half expired. - expired_access_tokens = AccessToken.objects.bulk_create( - AccessToken(token="expired AccessToken {}".format(i), expires=earlier) + expired_access_tokens = [ + AccessToken(token="expired AccessToken {}".format(i), expires=self.earlier) for i in range(self.num_tokens) - ) - current_access_tokens = AccessToken.objects.bulk_create( - AccessToken(token=f"current AccessToken {i}", expires=later) for i in range(self.num_tokens) - ) + ] + for a in expired_access_tokens: + a.save() + + current_access_tokens = [ + AccessToken(token=f"current AccessToken {i}", expires=self.later) for i in range(self.num_tokens) + ] + for a in current_access_tokens: + a.save() + # Give the first half of the access tokens a refresh token, # alternating between current and expired ones. - RefreshToken.objects.bulk_create( + for i in range(0, len(expired_access_tokens) // 2, 2): RefreshToken( token=f"expired AT's refresh token {i}", application=app, - access_token=expired_access_tokens[i].pk, + access_token=expired_access_tokens[i], user=self.user, - ) - for i in range(0, len(expired_access_tokens) // 2, 2) - ) - RefreshToken.objects.bulk_create( + ).save() + + for i in range(1, len(current_access_tokens) // 2, 2): RefreshToken( token=f"current AT's refresh token {i}", application=app, - access_token=current_access_tokens[i].pk, + access_token=current_access_tokens[i], user=self.user, - ) - for i in range(1, len(current_access_tokens) // 2, 2) - ) + ).save() + # Make some grants, half of which are expired. - Grant.objects.bulk_create( + for i in range(self.num_tokens): Grant( user=self.user, code=f"old grant code {i}", application=app, - expires=earlier, + expires=self.earlier, redirect_uri="https://localhost/redirect", - ) - for i in range(self.num_tokens) - ) - Grant.objects.bulk_create( + ).save() + for i in range(self.num_tokens): Grant( user=self.user, code=f"new grant code {i}", application=app, - expires=later, + expires=self.later, redirect_uri="https://localhost/redirect", - ) - for i in range(self.num_tokens) - ) - - def test_clear_expired_tokens(self): - self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = 60 - assert clear_expired() is None + ).save() def test_clear_expired_tokens_incorect_timetype(self): self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = "A" @@ -372,19 +370,61 @@ def test_clear_expired_tokens_incorect_timetype(self): def test_clear_expired_tokens_with_tokens(self): self.oauth2_settings.CLEAR_EXPIRED_TOKENS_BATCH_SIZE = 10 self.oauth2_settings.CLEAR_EXPIRED_TOKENS_BATCH_INTERVAL = 0.0 - at_count = AccessToken.objects.count() - assert at_count == 2 * self.num_tokens, f"{2 * self.num_tokens} access tokens should exist." - rt_count = RefreshToken.objects.count() - assert rt_count == self.num_tokens // 2, f"{self.num_tokens // 2} refresh tokens should exist." - gt_count = Grant.objects.count() - assert gt_count == self.num_tokens * 2, f"{self.num_tokens * 2} grants should exist." + self.oauth2_settings.REFRESH_TOKEN_EXPIRE_SECONDS = self.delta_secs // 2 + + # before clear_expired(), confirm setup as expected + initial_at_count = AccessToken.objects.count() + assert initial_at_count == 2 * self.num_tokens, f"{2 * self.num_tokens} access tokens should exist." + initial_expired_at_count = AccessToken.objects.filter(expires__lte=self.now).count() + assert ( + initial_expired_at_count == self.num_tokens + ), f"{self.num_tokens} expired access tokens should exist." + initial_current_at_count = AccessToken.objects.filter(expires__gt=self.now).count() + assert ( + initial_current_at_count == self.num_tokens + ), f"{self.num_tokens} current access tokens should exist." + initial_rt_count = RefreshToken.objects.count() + assert ( + initial_rt_count == self.num_tokens // 2 + ), f"{self.num_tokens // 2} refresh tokens should exist." + initial_rt_expired_at_count = RefreshToken.objects.filter(access_token__expires__lte=self.now).count() + assert ( + initial_rt_expired_at_count == initial_rt_count / 2 + ), "half the refresh tokens should be for expired access tokens." + initial_rt_current_at_count = RefreshToken.objects.filter(access_token__expires__gt=self.now).count() + assert ( + initial_rt_current_at_count == initial_rt_count / 2 + ), "half the refresh tokens should be for current access tokens." + initial_gt_count = Grant.objects.count() + assert initial_gt_count == self.num_tokens * 2, f"{self.num_tokens * 2} grants should exist." + clear_expired() - at_count = AccessToken.objects.count() - assert at_count == self.num_tokens, "Half the access tokens should not have been deleted." - rt_count = RefreshToken.objects.count() - assert rt_count == self.num_tokens // 2, "Half of the refresh tokens should have been deleted." - gt_count = Grant.objects.count() - assert gt_count == self.num_tokens, "Half the grants should have been deleted." + + # after clear_expired(): + remaining_at_count = AccessToken.objects.count() + assert ( + remaining_at_count == initial_at_count // 2 + ), "half the initial access tokens should still exist." + remaining_expired_at_count = AccessToken.objects.filter(expires__lte=self.now).count() + assert remaining_expired_at_count == 0, "no remaining expired access tokens should still exist." + remaining_current_at_count = AccessToken.objects.filter(expires__gt=self.now).count() + assert ( + remaining_current_at_count == initial_current_at_count + ), "all current access tokens should still exist." + remaining_rt_count = RefreshToken.objects.count() + assert remaining_rt_count == initial_rt_count // 2, "half the refresh tokens should still exist." + remaining_rt_expired_at_count = RefreshToken.objects.filter( + access_token__expires__lte=self.now + ).count() + assert remaining_rt_expired_at_count == 0, "no refresh tokens for expired AT's should still exist." + remaining_rt_current_at_count = RefreshToken.objects.filter( + access_token__expires__gt=self.now + ).count() + assert ( + remaining_rt_current_at_count == initial_rt_current_at_count + ), "all the refresh tokens for current access tokens should still exist." + remaining_gt_count = Grant.objects.count() + assert remaining_gt_count == initial_gt_count // 2, "half the remaining grants should still exist." @pytest.mark.django_db