Skip to content

Commit dbddebf

Browse files
committed
No more magic.
1 parent 08f5021 commit dbddebf

File tree

5 files changed

+51
-49
lines changed

5 files changed

+51
-49
lines changed

tests/common_testing.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
# no assumption that 'default' is a valid database.
1111
# For any test that would usually use Django's TestCase or TransactionTestCase using
1212
# the classes defined here is all that is required.
13-
# Any test that uses pytest's django_db need to include a databases parameter using
14-
# test_database_names defined below.
1513
# In test code, anywhere the database is referenced the Django router needs to be used
1614
# exactly like the package's code.
1715
# For instance:
@@ -20,11 +18,19 @@
2018
# Without the 'using' option, this test fails in the multiple database scenario because
2119
# 'default' is used.
2220

23-
test_database_names = ["alpha", "beta"] if len(settings.DATABASES) > 1 else ["default"]
21+
22+
def retrieve_current_databases():
23+
if len(settings.DATABASES) > 1:
24+
return [name for name in settings.DATABASES if name != "default"]
25+
else:
26+
return ["default"]
2427

2528

2629
class OAuth2ProviderBase:
27-
databases = test_database_names
30+
@classmethod
31+
def setUpClass(cls):
32+
cls.databases = retrieve_current_databases()
33+
super().setUpClass()
2834

2935

3036
class OAuth2ProviderTestCase(OAuth2ProviderBase, DjangoTestCase):

tests/test_hybrid.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from . import presets
2424
from .common_testing import OAuth2ProviderTestCase as TestCase
25-
from .common_testing import test_database_names
2625
from .utils import get_basic_auth_header, spy_on
2726

2827

@@ -1320,7 +1319,7 @@ def test_pre_auth_default_scopes(self):
13201319
self.assertEqual(form["client_id"].value(), self.application.client_id)
13211320

13221321

1323-
@pytest.mark.django_db(databases=test_database_names)
1322+
@pytest.mark.django_db
13241323
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
13251324
def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_application, client, oidc_key):
13261325
client.force_login(test_user)
@@ -1369,7 +1368,7 @@ def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_app
13691368
assert claims["nonce"] == "random_nonce_string"
13701369

13711370

1372-
@pytest.mark.django_db(databases=test_database_names)
1371+
@pytest.mark.django_db
13731372
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
13741373
def test_claims_passed_to_code_generation(
13751374
oauth2_settings, test_user, hybrid_application, client, mocker, oidc_key

tests/test_models.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from . import presets
2222
from .common_testing import OAuth2ProviderTestCase as TestCase
23-
from .common_testing import test_database_names
2423

2524

2625
CLEARTEXT_SECRET = "1234567890abcdefghijklmnopqrstuvwxyz"
@@ -467,7 +466,7 @@ def test_clear_expired_tokens_with_tokens(self):
467466
assert remaining_gt_count == initial_gt_count // 2, "half the remaining grants should still exist."
468467

469468

470-
@pytest.mark.django_db(databases=test_database_names)
469+
@pytest.mark.django_db
471470
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
472471
def test_id_token_methods(oidc_tokens, rf):
473472
id_token = IDToken.objects.get()
@@ -502,7 +501,7 @@ def test_id_token_methods(oidc_tokens, rf):
502501
assert IDToken.objects.filter(jti=id_token.jti).count() == 0
503502

504503

505-
@pytest.mark.django_db(databases=test_database_names)
504+
@pytest.mark.django_db
506505
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
507506
def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf):
508507
id_token = IDToken.objects.get()
@@ -541,7 +540,7 @@ def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf):
541540
assert not IDToken.objects.filter(jti=id_token.jti).exists()
542541

543542

544-
@pytest.mark.django_db(databases=test_database_names)
543+
@pytest.mark.django_db
545544
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
546545
def test_application_key(oauth2_settings, application):
547546
# RS256 key
@@ -566,7 +565,7 @@ def test_application_key(oauth2_settings, application):
566565
assert "This application does not support signed tokens" == str(exc.value)
567566

568567

569-
@pytest.mark.django_db(databases=test_database_names)
568+
@pytest.mark.django_db
570569
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
571570
def test_application_clean(oauth2_settings, application):
572571
# RS256, RSA key is configured
@@ -606,15 +605,15 @@ def test_application_clean(oauth2_settings, application):
606605
application.clean()
607606

608607

609-
@pytest.mark.django_db(databases=test_database_names)
608+
@pytest.mark.django_db
610609
@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT)
611610
def test_application_origin_allowed_default_https(oauth2_settings, cors_application):
612611
"""Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https"""
613612
assert cors_application.origin_allowed("https://example.com")
614613
assert not cors_application.origin_allowed("http://example.com")
615614

616615

617-
@pytest.mark.django_db(databases=test_database_names)
616+
@pytest.mark.django_db
618617
@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP)
619618
def test_application_origin_allowed_http(oauth2_settings, cors_application):
620619
"""Test that http schemes are allowed because http was added to ALLOWED_SCHEMES"""

tests/test_oauth2_validators.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from . import presets
1818
from .common_testing import OAuth2ProviderTestCase as TestCase
1919
from .common_testing import OAuth2ProviderTransactionTestCase as TransactionTestCase
20-
from .common_testing import test_database_names
2120
from .utils import get_basic_auth_header
2221

2322

@@ -547,7 +546,7 @@ def test_get_jwt_bearer_token(oauth2_settings, mocker):
547546
assert mock_get_id_token.call_args[1] == {}
548547

549548

550-
@pytest.mark.django_db(databases=test_database_names)
549+
@pytest.mark.django_db
551550
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
552551
def test_validate_id_token_expired_jwt(oauth2_settings, mocker, oidc_tokens):
553552
mocker.patch("oauth2_provider.oauth2_validators.jwt.JWT", side_effect=jwt.JWTExpired)
@@ -563,7 +562,7 @@ def test_validate_id_token_no_token(oauth2_settings, mocker):
563562
assert status is False
564563

565564

566-
@pytest.mark.django_db(databases=test_database_names)
565+
@pytest.mark.django_db
567566
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
568567
def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens):
569568
oidc_tokens.application.delete()
@@ -572,7 +571,7 @@ def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens):
572571
assert status is False
573572

574573

575-
@pytest.mark.django_db(databases=test_database_names)
574+
@pytest.mark.django_db
576575
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
577576
def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key):
578577
token = jwt.JWT(header=json.dumps({"alg": "RS256"}), claims=json.dumps({"bad": "token"}))

0 commit comments

Comments
 (0)