Skip to content

Commit 5921788

Browse files
committed
Improve multiple database support.
The token models might not be stored in the default database. There might not _be_ a default database. Intead, the code now relies on Django's routers to determine the actual database to use when creating transactions. This required moving from decorators to context managers for those transactions. To test the multiple database scenario a new settings file as added which derives from settings.py and then defines different databases and the routers needed to access them. The commit is larger than might be expected because when there are multiple databases the Django tests have to be told which databases to work on. Rather than copying the various test cases or making multiple database specific ones the decision was made to add wrappers around the standard Django TestCase classes and programmatically define the databases for them. This enables all of the same test code to work for both the one database and the multi database scenarios with minimal maintenance costs. A tox environment that uses the multi db settings file has been added to ensure both scenarios are always tested.
1 parent ba75297 commit 5921788

30 files changed

+230
-75
lines changed

oauth2_provider/models.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import logging
22
import time
33
import uuid
4+
from contextlib import suppress
45
from datetime import timedelta
56
from urllib.parse import parse_qsl, urlparse
67

78
from django.apps import apps
89
from django.conf import settings
910
from django.contrib.auth.hashers import identify_hasher, make_password
1011
from django.core.exceptions import ImproperlyConfigured
11-
from django.db import models, transaction
12+
from django.db import models, router, transaction
1213
from django.urls import reverse
1314
from django.utils import timezone
1415
from django.utils.translation import gettext_lazy as _
@@ -501,16 +502,27 @@ def revoke(self):
501502
"""
502503
access_token_model = get_access_token_model()
503504
refresh_token_model = get_refresh_token_model()
504-
with transaction.atomic():
505+
506+
access_token_database = router.db_for_write(access_token_model)
507+
refresh_token_database = router.db_for_write(refresh_token_model)
508+
509+
# This is highly unlikely, but let's warn people just in case it does.
510+
if access_token_database != refresh_token_database:
511+
logger.warning(
512+
"access token and refresh token are in separate databases but a transaction"
513+
" is only used for the access token"
514+
)
515+
516+
# Use the access_token_database instead of making the assumption it is in 'default'.
517+
with transaction.atomic(using=access_token_database):
505518
token = refresh_token_model.objects.select_for_update().filter(pk=self.pk, revoked__isnull=True)
506519
if not token:
507520
return
508521
self = list(token)[0]
509522

510-
try:
523+
with suppress(access_token_model.DoesNotExist):
511524
access_token_model.objects.get(id=self.access_token_id).revoke()
512-
except access_token_model.DoesNotExist:
513-
pass
525+
514526
self.access_token = None
515527
self.revoked = timezone.now()
516528
self.save()

oauth2_provider/oauth2_validators.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from django.contrib.auth import authenticate, get_user_model
1515
from django.contrib.auth.hashers import check_password, identify_hasher
1616
from django.core.exceptions import ObjectDoesNotExist
17-
from django.db import transaction
17+
from django.db import router, transaction
1818
from django.db.models import Q
1919
from django.http import HttpRequest
2020
from django.utils import dateformat, timezone
@@ -557,8 +557,12 @@ def rotate_refresh_token(self, request):
557557
"""
558558
return oauth2_settings.ROTATE_REFRESH_TOKEN
559559

560-
@transaction.atomic
561560
def save_bearer_token(self, token, request, *args, **kwargs):
561+
# Use the AccessToken's database instead of making the assumption it is in 'default'.
562+
with transaction.atomic(using=router.db_for_write(AccessToken)):
563+
return self._save_bearer_token_internals(token, request, *args, **kwargs)
564+
565+
def _save_bearer_token_internals(self, token, request, *args, **kwargs):
562566
"""
563567
Save access and refresh token, If refresh token is issued, remove or
564568
reuse old refresh token as in rfc:`6`
@@ -770,7 +774,6 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs
770774
request.refresh_token_instance = rt
771775
return rt.application == client
772776

773-
@transaction.atomic
774777
def _save_id_token(self, jti, request, expires, *args, **kwargs):
775778
scopes = request.scope or " ".join(request.scopes)
776779

@@ -871,7 +874,9 @@ def finalize_id_token(self, id_token, token, token_handler, request):
871874
claims=json.dumps(id_token, default=str),
872875
)
873876
jwt_token.make_signed_token(request.client.jwk_key)
874-
id_token = self._save_id_token(id_token["jti"], request, expiration_time)
877+
# Use the IDToken's database instead of making the assumption it is in 'default'.
878+
with transaction.atomic(using=router.db_for_write(IDToken)):
879+
id_token = self._save_id_token(id_token["jti"], request, expiration_time)
875880
# this is needed by django rest framework
876881
request.access_token = id_token
877882
request.id_token = id_token

tests/common_testing.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from django.conf import settings
2+
from django.test import TestCase as DjangoTestCase
3+
from django.test import TransactionTestCase as DjangoTransactionTestCase
4+
5+
6+
class OAuth2ProviderTestCase(DjangoTestCase):
7+
"""Place holder to allow overriding behaviors."""
8+
9+
10+
class OAuth2ProviderTransactionTestCase(DjangoTransactionTestCase):
11+
"""Place holder to allow overriding behaviors."""
12+
13+
14+
if len(settings.DATABASES) > 1:
15+
# There are multiple databases defined. When this happens Django tests will not
16+
# work unless they are told which database(s) to work with. The multiple
17+
# database scenario setup for these tests purposefully defines 'default' as an
18+
# empty database in order to catch any assumptions in this package about database
19+
# names and in particular to ensure there is no assumption that 'default' is a
20+
# valid database.
21+
# For any test that would usually use Django's TestCase or TransactionTestCase
22+
# using the classes defined here is all that is required.
23+
# Any test that uses pytest's django_db need to base in a databases parameter
24+
# using this definition of test_database_names.
25+
# In test code, anywhere the default database is used the variable
26+
# database_for_oauth2_provider must be used in its place. For instance,
27+
# with self.assertNumQueries(1, using=database_for_oauth2_provider):
28+
# without the using option this fails because default is used.
29+
test_database_names = {name for name in settings.DATABASES if name != "default"}
30+
database_for_oauth2_provider = "alpha"
31+
OAuth2ProviderTestCase.databases = test_database_names
32+
OAuth2ProviderTransactionTestCase.databases = test_database_names
33+
else:
34+
test_database_names = {"default"}
35+
database_for_oauth2_provider = "default"

tests/db_router.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
apps_in_beta = {"some_other_app", "this_one_too"}
2+
3+
# These are bare minimum routers to fake the scenario where there is actually a
4+
# decision around where an application's models might live.
5+
# alpha is where the core Django models are stored including user. To keep things
6+
# simple this is where the oauth2 provider models are stored as well because they
7+
# have a foreign key to User.
8+
9+
10+
class AlphaRouter:
11+
def db_for_read(self, model, **hints):
12+
if model._meta.app_label not in apps_in_beta:
13+
return "alpha"
14+
return None
15+
16+
def db_for_write(self, model, **hints):
17+
if model._meta.app_label not in apps_in_beta:
18+
return "alpha"
19+
return None
20+
21+
def allow_relation(self, obj1, obj2, **hints):
22+
if obj1._state.db == "alpha" and obj2._state.db == "alpha":
23+
return True
24+
return None
25+
26+
def allow_migrate(self, db, app_label, model_name=None, **hints):
27+
if app_label not in apps_in_beta:
28+
return db == "alpha"
29+
return None
30+
31+
32+
class BetaRouter:
33+
def db_for_read(self, model, **hints):
34+
if model._meta.app_label in apps_in_beta:
35+
return "beta"
36+
return None
37+
38+
def db_for_write(self, model, **hints):
39+
if model._meta.app_label in apps_in_beta:
40+
return "beta"
41+
return None
42+
43+
def allow_relation(self, obj1, obj2, **hints):
44+
if obj1._state.db == "beta" and obj2._state.db == "beta":
45+
return True
46+
return None
47+
48+
def allow_migrate(self, db, app_label, model_name=None, **hints):
49+
if app_label in apps_in_beta:
50+
return db == "beta"
51+
return None

tests/multi_db_settings.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Import the test settings and then override DATABASES.
2+
3+
from .settings import * # noqa: F401, F403
4+
5+
6+
DATABASES = {
7+
"alpha": {
8+
"ENGINE": "django.db.backends.sqlite3",
9+
"NAME": ":memory:",
10+
},
11+
"beta": {
12+
"ENGINE": "django.db.backends.sqlite3",
13+
"NAME": ":memory:",
14+
},
15+
# As https://docs.djangoproject.com/en/4.2/topics/db/multi-db/#defining-your-databases
16+
# indicates, it is ok to have no default database.
17+
"default": {},
18+
}
19+
DATABASE_ROUTERS = ["tests.db_router.AlphaRouter", "tests.db_router.BetaRouter"]

tests/test_application_views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import pytest
22
from django.contrib.auth import get_user_model
3-
from django.test import TestCase
43
from django.urls import reverse
54

65
from oauth2_provider.models import get_application_model
76
from oauth2_provider.views.application import ApplicationRegistration
87

8+
from .common_testing import OAuth2ProviderTestCase as TestCase
99
from .models import SampleApplication
1010

1111

tests/test_auth_backends.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55
from django.contrib.auth.models import AnonymousUser
66
from django.core.exceptions import SuspiciousOperation
77
from django.http import HttpResponse
8-
from django.test import RequestFactory, TestCase
8+
from django.test import RequestFactory
99
from django.test.utils import modify_settings, override_settings
1010
from django.utils.timezone import now, timedelta
1111

1212
from oauth2_provider.backends import OAuth2Backend
1313
from oauth2_provider.middleware import OAuth2ExtraTokenMiddleware, OAuth2TokenMiddleware
1414
from oauth2_provider.models import get_access_token_model, get_application_model
1515

16+
from .common_testing import OAuth2ProviderTestCase as TestCase
17+
1618

1719
UserModel = get_user_model()
1820
ApplicationModel = get_application_model()

tests/test_authorization_code.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88
from django.conf import settings
99
from django.contrib.auth import get_user_model
10-
from django.test import RequestFactory, TestCase
10+
from django.test import RequestFactory
1111
from django.urls import reverse
1212
from django.utils import timezone
1313
from django.utils.crypto import get_random_string
@@ -23,6 +23,7 @@
2323
from oauth2_provider.views import ProtectedResourceView
2424

2525
from . import presets
26+
from .common_testing import OAuth2ProviderTestCase as TestCase
2627
from .utils import get_basic_auth_header
2728

2829

tests/test_client_credential.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
from django.contrib.auth import get_user_model
66
from django.core.exceptions import SuspiciousOperation
7-
from django.test import RequestFactory, TestCase
7+
from django.test import RequestFactory
88
from django.urls import reverse
99
from django.views.generic import View
1010
from oauthlib.oauth2 import BackendApplicationServer
@@ -16,6 +16,7 @@
1616
from oauth2_provider.views.mixins import OAuthLibMixin
1717

1818
from . import presets
19+
from .common_testing import OAuth2ProviderTestCase as TestCase
1920
from .utils import get_basic_auth_header
2021

2122

tests/test_commands.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from django.contrib.auth.hashers import check_password
66
from django.core.management import call_command
77
from django.core.management.base import CommandError
8-
from django.test import TestCase
98

109
from oauth2_provider.models import get_application_model
1110

1211
from . import presets
12+
from .common_testing import OAuth2ProviderTestCase as TestCase
1313

1414

1515
Application = get_application_model()

0 commit comments

Comments
 (0)