Skip to content

change token to TextField in AbstractAccessToken model #1447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ Dylan Tack
Eduardo Oliveira
Egor Poderiagin
Emanuele Palazzetti
Fazeel Ghafoor
Federico Dolce
Florian Demmer
Frederico Vieira
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]
### Added
* Add migration to include `token_checksum` field in AbstractAccessToken model.
* #1404 Add a new setting `REFRESH_TOKEN_REUSE_PROTECTION`
### Changed
* Update token to TextField from CharField with 255 character limit and SHA-256 checksum in AbstractAccessToken model. Removing the 255 character limit enables supporting JWT tokens with additional claims

* Update middleware, validators, and views to use token checksums instead of token for token retrieval and validation.
### Deprecated
### Removed
* #1425 Remove deprecated `RedirectURIValidator`, `WildcardSet` per #1345; `validate_logout_request` per #1274
Expand Down
4 changes: 3 additions & 1 deletion oauth2_provider/middleware.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import logging

from django.contrib.auth import authenticate
Expand Down Expand Up @@ -55,7 +56,8 @@ def __call__(self, request):
tokenstring = authheader.split()[1]
AccessToken = get_access_token_model()
try:
token = AccessToken.objects.get(token=tokenstring)
token_checksum = hashlib.sha256(tokenstring.encode("utf-8")).hexdigest()
token = AccessToken.objects.get(token_checksum=token_checksum)
request.access_token = token
except AccessToken.DoesNotExist as e:
log.exception(e)
Expand Down
26 changes: 26 additions & 0 deletions oauth2_provider/migrations/0012_add_token_checksum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Generated by Django 5.0.7 on 2024-07-29 23:13

import oauth2_provider.models
from django.db import migrations, models
from oauth2_provider.settings import oauth2_settings

class Migration(migrations.Migration):
dependencies = [
("oauth2_provider", "0011_refreshtoken_token_family"),
migrations.swappable_dependency(oauth2_settings.ACCESS_TOKEN_MODEL),
]

operations = [
migrations.AddField(
model_name="accesstoken",
name="token_checksum",
field=oauth2_provider.models.TokenChecksumField(
blank=True, db_index=True, max_length=64, unique=True
),
),
migrations.AlterField(
model_name="accesstoken",
name="token",
field=models.TextField(),
),
]
15 changes: 13 additions & 2 deletions oauth2_provider/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import logging
import time
import uuid
Expand Down Expand Up @@ -44,6 +45,14 @@ def pre_save(self, model_instance, add):
return super().pre_save(model_instance, add)


class TokenChecksumField(models.CharField):
def pre_save(self, model_instance, add):
token = getattr(model_instance, "token")
checksum = hashlib.sha256(token.encode("utf-8")).hexdigest()
setattr(model_instance, self.attname, checksum)
return super().pre_save(model_instance, add)


class AbstractApplication(models.Model):
"""
An Application instance represents a Client on the Authorization server.
Expand Down Expand Up @@ -379,8 +388,10 @@ class AbstractAccessToken(models.Model):
null=True,
related_name="refreshed_access_token",
)
token = models.CharField(
max_length=255,
token = models.TextField()
token_checksum = TokenChecksumField(
max_length=64,
blank=True,
unique=True,
db_index=True,
)
Expand Down
8 changes: 7 additions & 1 deletion oauth2_provider/oauth2_validators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import binascii
import hashlib
import http.client
import inspect
import json
Expand Down Expand Up @@ -461,7 +462,12 @@ def validate_bearer_token(self, token, scopes, request):
return False

def _load_access_token(self, token):
return AccessToken.objects.select_related("application", "user").filter(token=token).first()
token_checksum = hashlib.sha256(token.encode("utf-8")).hexdigest()
return (
AccessToken.objects.select_related("application", "user")
.filter(token_checksum=token_checksum)
.first()
)

def validate_code(self, client_id, code, client, request, *args, **kwargs):
try:
Expand Down
4 changes: 3 additions & 1 deletion oauth2_provider/views/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import json
import logging
from urllib.parse import parse_qsl, urlencode, urlparse
Expand Down Expand Up @@ -289,7 +290,8 @@ def post(self, request, *args, **kwargs):
if status == 200:
access_token = json.loads(body).get("access_token")
if access_token is not None:
token = get_access_token_model().objects.get(token=access_token)
token_checksum = hashlib.sha256(access_token.encode("utf-8")).hexdigest()
token = get_access_token_model().objects.get(token_checksum=token_checksum)
app_authorized.send(sender=self, request=request, token=token)
response = HttpResponse(content=body, status=status)

Expand Down
6 changes: 5 additions & 1 deletion oauth2_provider/views/introspect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import calendar
import hashlib

from django.core.exceptions import ObjectDoesNotExist
from django.http import JsonResponse
Expand All @@ -24,8 +25,11 @@ class IntrospectTokenView(ClientProtectedScopedResourceView):
@staticmethod
def get_token_response(token_value=None):
try:
token_checksum = hashlib.sha256(token_value.encode("utf-8")).hexdigest()
token = (
get_access_token_model().objects.select_related("user", "application").get(token=token_value)
get_access_token_model()
.objects.select_related("user", "application")
.get(token_checksum=token_checksum)
)
except ObjectDoesNotExist:
return JsonResponse({"active": False}, status=200)
Expand Down
12 changes: 8 additions & 4 deletions tests/migrations/0002_swapped_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,14 @@ class Migration(migrations.Migration):
field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='s_refreshed_access_token', to=settings.OAUTH2_PROVIDER_REFRESH_TOKEN_MODEL),
),
migrations.AddField(
model_name='sampleaccesstoken',
name='token',
field=models.CharField(max_length=255, unique=True),
preserve_default=False,
model_name="sampleaccesstoken",
name="token",
field=models.TextField(),
),
migrations.AddField(
model_name="sampleaccesstoken",
name="token_checksum",
field=models.CharField(max_length=64, unique=True, db_index=True),
),
migrations.AddField(
model_name='sampleaccesstoken',
Expand Down
13 changes: 13 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import secrets
from datetime import timedelta

import pytest
Expand Down Expand Up @@ -310,6 +312,17 @@ def test_expires_can_be_none(self):
self.assertIsNone(access_token.expires)
self.assertTrue(access_token.is_expired())

def test_token_checksum_field(self):
token = secrets.token_urlsafe(32)
access_token = AccessToken.objects.create(
user=self.user,
token=token,
expires=timezone.now() + timedelta(hours=1),
)
expected_checksum = hashlib.sha256(token.encode()).hexdigest()

self.assertEqual(access_token.token_checksum, expected_checksum)


class TestRefreshTokenModel(BaseTestModels):
def test_str(self):
Expand Down
Loading