Skip to content

Honor database assignment from router #1450

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ Rodney Richardson
Rustem Saiargaliev
Rustem Saiargaliev
Sandro Rodrigues
Sean 'Shaleh' Perry
Shaheed Haque
Shaun Stanworth
Sayyid Hamid Mahdavi
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Update token to TextField from CharField with 255 character limit and SHA-256 checksum in AbstractAccessToken model. Removing the 255 character limit enables supporting JWT tokens with additional claims
* Update middleware, validators, and views to use token checksums instead of token for token retrieval and validation.
* #1446 use generic models pk instead of id.
* Transactions wrapping writes of the Tokens now rely on Django's database routers to determine the correct
database to use instead of assuming that 'default' is the correct one.

### Deprecated
### Removed
Expand Down
11 changes: 11 additions & 0 deletions docs/advanced_topics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,17 @@ That's all, now Django OAuth Toolkit will use your model wherever an Application
is because of the way Django currently implements swappable models.
See `issue #90 <https://github.com/jazzband/django-oauth-toolkit/issues/90>`_ for details.

Configuring multiple databases
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

There is no requirement that the tokens are stored in the default database or that there is a
default database provided the database routers can determine the correct Token locations. Because the
Tokens have foreign keys to the ``User`` model, you likely want to keep the tokens in the same database
as your User model. It is also important that all of the tokens are stored in the same database.
This could happen for instance if one of the Tokens is locally overridden and stored in a separate database.
The reason for this is transactions will only be made for the database where AccessToken is stored
even when writing to RefreshToken or other tokens.

Multiple Grants
~~~~~~~~~~~~~~~

Expand Down
20 changes: 20 additions & 0 deletions docs/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,26 @@ Open :file:`mycoverage/index.html` in your browser and you can see a coverage su

There's no need to wait for Codecov to complain after you submit your PR.

The tests are generic and written to work with both single database and multiple database configurations. tox will run
tests both ways. You can see the configurations used in tests/settings.py and tests/multi_db_settings.py.

When there are multiple databases defined, Django tests will not work unless they are told which database(s) to work with.
For test writers this means any test must either:
- instead of Django's TestCase or TransactionTestCase use the versions of those
classes defined in tests/common_testing.py
- when using pytest's `django_db` mark, define it like this:
`@pytest.mark.django_db(databases=retrieve_current_databases())`

In test code, anywhere the database is referenced the Django router needs to be used exactly like the package's code.

.. code-block:: python

token_database = router.db_for_write(AccessToken)
with self.assertNumQueries(1, using=token_database):
# call something using the database

Without the 'using' option, this test fails in the multiple database scenario because 'default' will be used instead.

Code conventions matter
-----------------------

Expand Down
4 changes: 4 additions & 0 deletions oauth2_provider/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@
class DOTConfig(AppConfig):
name = "oauth2_provider"
verbose_name = "Django OAuth Toolkit"

def ready(self):
# Import checks to ensure they run.
from . import checks # noqa: F401
28 changes: 28 additions & 0 deletions oauth2_provider/checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from django.apps import apps
from django.core import checks
from django.db import router

from .settings import oauth2_settings


@checks.register(checks.Tags.database)
def validate_token_configuration(app_configs, **kwargs):
databases = set(
router.db_for_write(apps.get_model(model))
for model in (
oauth2_settings.ACCESS_TOKEN_MODEL,
oauth2_settings.ID_TOKEN_MODEL,
oauth2_settings.REFRESH_TOKEN_MODEL,
)
)

# This is highly unlikely, but let's warn people just in case it does.
# If the tokens were allowed to be in different databases this would require all
# writes to have a transaction around each database. Instead, let's enforce that
# they all live together in one database.
# The tokens are not required to live in the default database provided the Django
# routers know the correct database for them.
if len(databases) > 1:
return [checks.Error("The token models are expected to be stored in the same database.")]

return []
17 changes: 10 additions & 7 deletions oauth2_provider/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import logging
import time
import uuid
from contextlib import suppress
from datetime import timedelta
from urllib.parse import parse_qsl, urlparse

from django.apps import apps
from django.conf import settings
from django.contrib.auth.hashers import identify_hasher, make_password
from django.core.exceptions import ImproperlyConfigured
from django.db import models, transaction
from django.db import models, router, transaction
from django.urls import reverse
from django.utils import timezone
from django.utils.translation import gettext_lazy as _
Expand Down Expand Up @@ -512,17 +513,19 @@ def revoke(self):
Mark this refresh token revoked and revoke related access token
"""
access_token_model = get_access_token_model()
access_token_database = router.db_for_write(access_token_model)
refresh_token_model = get_refresh_token_model()
with transaction.atomic():

# Use the access_token_database instead of making the assumption it is in 'default'.
with transaction.atomic(using=access_token_database):
token = refresh_token_model.objects.select_for_update().filter(pk=self.pk, revoked__isnull=True)
if not token:
return
self = list(token)[0]

try:
access_token_model.objects.get(pk=self.access_token_id).revoke()
except access_token_model.DoesNotExist:
pass
with suppress(access_token_model.DoesNotExist):
access_token_model.objects.get(id=self.access_token_id).revoke()

self.access_token = None
self.revoked = timezone.now()
self.save()
Expand Down Expand Up @@ -655,7 +658,7 @@ def get_access_token_model():


def get_id_token_model():
"""Return the AccessToken model that is active in this project."""
"""Return the IDToken model that is active in this project."""
return apps.get_model(oauth2_settings.ID_TOKEN_MODEL)


Expand Down
25 changes: 19 additions & 6 deletions oauth2_provider/oauth2_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth.hashers import check_password, identify_hasher
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
from django.db import router, transaction
from django.http import HttpRequest
from django.utils import dateformat, timezone
from django.utils.crypto import constant_time_compare
Expand Down Expand Up @@ -562,11 +562,23 @@ def rotate_refresh_token(self, request):
"""
return oauth2_settings.ROTATE_REFRESH_TOKEN

@transaction.atomic
def save_bearer_token(self, token, request, *args, **kwargs):
"""
Save access and refresh token, If refresh token is issued, remove or
reuse old refresh token as in rfc:`6`
Save access and refresh token.

Override _save_bearer_token and not this function when adding custom logic
for the storing of these token. This allows the transaction logic to be
separate from the token handling.
"""
# Use the AccessToken's database instead of making the assumption it is in 'default'.
with transaction.atomic(using=router.db_for_write(AccessToken)):
return self._save_bearer_token(token, request, *args, **kwargs)

def _save_bearer_token(self, token, request, *args, **kwargs):
"""
Save access and refresh token.

If refresh token is issued, remove or reuse old refresh token as in rfc:`6`.

@see: https://rfc-editor.org/rfc/rfc6749.html#section-6
"""
Expand Down Expand Up @@ -788,7 +800,6 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs

return rt.application == client

@transaction.atomic
def _save_id_token(self, jti, request, expires, *args, **kwargs):
scopes = request.scope or " ".join(request.scopes)

Expand Down Expand Up @@ -889,7 +900,9 @@ def finalize_id_token(self, id_token, token, token_handler, request):
claims=json.dumps(id_token, default=str),
)
jwt_token.make_signed_token(request.client.jwk_key)
id_token = self._save_id_token(id_token["jti"], request, expiration_time)
# Use the IDToken's database instead of making the assumption it is in 'default'.
with transaction.atomic(using=router.db_for_write(IDToken)):
id_token = self._save_id_token(id_token["jti"], request, expiration_time)
# this is needed by django rest framework
request.access_token = id_token
request.id_token = id_token
Expand Down
33 changes: 33 additions & 0 deletions tests/common_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from django.conf import settings
from django.test import TestCase as DjangoTestCase
from django.test import TransactionTestCase as DjangoTransactionTestCase


# The multiple database scenario setup for these tests purposefully defines 'default' as
# an empty database in order to catch any assumptions in this package about database names
# and in particular to ensure there is no assumption that 'default' is a valid database.
#
# When there are multiple databases defined, Django tests will not work unless they are
# told which database(s) to work with.


def retrieve_current_databases():
if len(settings.DATABASES) > 1:
return [name for name in settings.DATABASES if name != "default"]
else:
return ["default"]


class OAuth2ProviderBase:
@classmethod
def setUpClass(cls):
cls.databases = retrieve_current_databases()
super().setUpClass()


class OAuth2ProviderTestCase(OAuth2ProviderBase, DjangoTestCase):
"""Place holder to allow overriding behaviors."""


class OAuth2ProviderTransactionTestCase(OAuth2ProviderBase, DjangoTransactionTestCase):
"""Place holder to allow overriding behaviors."""
76 changes: 76 additions & 0 deletions tests/db_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
apps_in_beta = {"some_other_app", "this_one_too"}

# These are bare minimum routers to fake the scenario where there is actually a
# decision around where an application's models might live.


class AlphaRouter:
# alpha is where the core Django models are stored including user. To keep things
# simple this is where the oauth2 provider models are stored as well because they
# have a foreign key to User.

def db_for_read(self, model, **hints):
if model._meta.app_label not in apps_in_beta:
return "alpha"
return None

def db_for_write(self, model, **hints):
if model._meta.app_label not in apps_in_beta:
return "alpha"
return None

def allow_relation(self, obj1, obj2, **hints):
if obj1._state.db == "alpha" and obj2._state.db == "alpha":
return True
return None

def allow_migrate(self, db, app_label, model_name=None, **hints):
if app_label not in apps_in_beta:
return db == "alpha"
return None


class BetaRouter:
def db_for_read(self, model, **hints):
if model._meta.app_label in apps_in_beta:
return "beta"
return None

def db_for_write(self, model, **hints):
if model._meta.app_label in apps_in_beta:
return "beta"
return None

def allow_relation(self, obj1, obj2, **hints):
if obj1._state.db == "beta" and obj2._state.db == "beta":
return True
return None

def allow_migrate(self, db, app_label, model_name=None, **hints):
if app_label in apps_in_beta:
return db == "beta"


class CrossDatabaseRouter:
# alpha is where the core Django models are stored including user. To keep things
# simple this is where the oauth2 provider models are stored as well because they
# have a foreign key to User.
def db_for_read(self, model, **hints):
if model._meta.model_name == "accesstoken":
return "beta"
return None

def db_for_write(self, model, **hints):
if model._meta.model_name == "accesstoken":
return "beta"
return None

def allow_relation(self, obj1, obj2, **hints):
if obj1._state.db == "beta" and obj2._state.db == "beta":
return True
return None

def allow_migrate(self, db, app_label, model_name=None, **hints):
if model_name == "accesstoken":
return db == "beta"
return None
34 changes: 34 additions & 0 deletions tests/migrations/0007_add_localidtoken.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Generated by Django 3.2.25 on 2024-08-08 22:47

from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import uuid


class Migration(migrations.Migration):

dependencies = [
migrations.swappable_dependency(settings.OAUTH2_PROVIDER_APPLICATION_MODEL),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('tests', '0006_basetestapplication_token_family'),
]

operations = [
migrations.CreateModel(
name='LocalIDToken',
fields=[
('id', models.BigAutoField(primary_key=True, serialize=False)),
('jti', models.UUIDField(default=uuid.uuid4, editable=False, unique=True, verbose_name='JWT Token ID')),
('expires', models.DateTimeField()),
('scope', models.TextField(blank=True)),
('created', models.DateTimeField(auto_now_add=True)),
('updated', models.DateTimeField(auto_now=True)),
('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)),
('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_localidtoken', to=settings.AUTH_USER_MODEL)),
],
options={
'abstract': False,
},
),
]
7 changes: 7 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
AbstractAccessToken,
AbstractApplication,
AbstractGrant,
AbstractIDToken,
AbstractRefreshToken,
)
from oauth2_provider.settings import oauth2_settings
Expand Down Expand Up @@ -54,3 +55,9 @@ class SampleRefreshToken(AbstractRefreshToken):

class SampleGrant(AbstractGrant):
custom_field = models.CharField(max_length=255)


class LocalIDToken(AbstractIDToken):
"""Exists to be improperly configured for multiple databases."""

# The other token types will be in 'alpha' database.
19 changes: 19 additions & 0 deletions tests/multi_db_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Import the test settings and then override DATABASES.

from .settings import * # noqa: F401, F403


DATABASES = {
"alpha": {
"ENGINE": "django.db.backends.sqlite3",
"NAME": ":memory:",
},
"beta": {
"ENGINE": "django.db.backends.sqlite3",
"NAME": ":memory:",
},
# As https://docs.djangoproject.com/en/4.2/topics/db/multi-db/#defining-your-databases
# indicates, it is ok to have no default database.
"default": {},
}
DATABASE_ROUTERS = ["tests.db_router.AlphaRouter", "tests.db_router.BetaRouter"]
8 changes: 8 additions & 0 deletions tests/multi_db_settings_invalid_token_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .multi_db_settings import * # noqa: F401, F403


OAUTH2_PROVIDER = {
# The other two tokens will be in alpha. This will cause a failure when the
# app's ready method is called.
"ID_TOKEN_MODEL": "tests.LocalIDToken",
}
Loading
Loading