Skip to content

Commit 7637a49

Browse files
committed
maybe a better test definition
1 parent 90b2ff3 commit 7637a49

File tree

2 files changed

+28
-27
lines changed

2 files changed

+28
-27
lines changed

tests/common_testing.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,33 @@
33
from django.test import TransactionTestCase as DjangoTransactionTestCase
44

55

6-
class OAuth2ProviderTestCase(DjangoTestCase):
7-
"""Place holder to allow overriding behaviors."""
6+
# When there are multiple databases defined, Django tests will not work unless they are
7+
# told which database(s) to work with. The multiple database scenario setup for these
8+
# tests purposefully defines 'default' as an empty database in order to catch any
9+
# assumptions in this package about database names and in particular to ensure there is
10+
# no assumption that 'default' is a valid database.
11+
# For any test that would usually use Django's TestCase or TransactionTestCase using
12+
# the classes defined here is all that is required.
13+
# Any test that uses pytest's django_db need to include a databases parameter using
14+
# test_database_names defined below.
15+
# In test code, anywhere the database is referenced the Django router needs to be used
16+
# exactly like the package's code.
17+
# For instance:
18+
# token_database = router.db_for_write(AccessToken)
19+
# with self.assertNumQueries(1, using=token_database):
20+
# Without the 'using' option, this test fails in the multiple database scenario because
21+
# 'default' is used.
22+
23+
test_database_names = ["alpha", "beta"] if len(settings.DATABASES) > 1 else ["default"]
24+
825

26+
class OAuth2ProviderBase:
27+
databases = test_database_names
928

10-
class OAuth2ProviderTransactionTestCase(DjangoTransactionTestCase):
29+
30+
class OAuth2ProviderTestCase(OAuth2ProviderBase, DjangoTestCase):
1131
"""Place holder to allow overriding behaviors."""
1232

1333

14-
if len(settings.DATABASES) > 1:
15-
# There are multiple databases defined. When this happens Django tests will not
16-
# work unless they are told which database(s) to work with. The multiple
17-
# database scenario setup for these tests purposefully defines 'default' as an
18-
# empty database in order to catch any assumptions in this package about database
19-
# names and in particular to ensure there is no assumption that 'default' is a
20-
# valid database.
21-
# For any test that would usually use Django's TestCase or TransactionTestCase
22-
# using the classes defined here is all that is required.
23-
# Any test that uses pytest's django_db need to base in a databases parameter
24-
# using this definition of test_database_names.
25-
# In test code, anywhere the default database is used the variable
26-
# database_for_oauth2_provider must be used in its place. For instance,
27-
# with self.assertNumQueries(1, using=database_for_oauth2_provider):
28-
# without the using option this fails because default is used.
29-
test_database_names = {name for name in settings.DATABASES if name != "default"}
30-
database_for_oauth2_provider = "alpha"
31-
OAuth2ProviderTestCase.databases = test_database_names
32-
OAuth2ProviderTransactionTestCase.databases = test_database_names
33-
else:
34-
test_database_names = {"default"}
35-
database_for_oauth2_provider = "default"
34+
class OAuth2ProviderTransactionTestCase(OAuth2ProviderBase, DjangoTransactionTestCase):
35+
"""Place holder to allow overriding behaviors."""

tests/test_introspection_view.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33

44
import pytest
55
from django.contrib.auth import get_user_model
6+
from django.db import router
67
from django.urls import reverse
78
from django.utils import timezone
89

910
from oauth2_provider.models import get_access_token_model, get_application_model
1011

1112
from . import presets
1213
from .common_testing import OAuth2ProviderTestCase as TestCase
13-
from .common_testing import database_for_oauth2_provider
1414
from .utils import get_basic_auth_header
1515

1616

@@ -344,5 +344,6 @@ def test_view_post_invalid_client_creds_plaintext(self):
344344
self.assertEqual(response.status_code, 403)
345345

346346
def test_select_related_in_view_for_less_db_queries(self):
347-
with self.assertNumQueries(1, using=database_for_oauth2_provider):
347+
token_database = router.db_for_write(AccessToken)
348+
with self.assertNumQueries(1, using=token_database):
348349
self.client.post(reverse("oauth2_provider:introspect"))

0 commit comments

Comments
 (0)