Skip to content

Commit 86b2519

Browse files
committed
Use retrieven_current_databases in django_db marked tests.
1 parent 8072cc7 commit 86b2519

File tree

4 files changed

+45
-41
lines changed

4 files changed

+45
-41
lines changed

tests/test_hybrid.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from . import presets
2424
from .common_testing import OAuth2ProviderTestCase as TestCase
25+
from .common_testing import retrieve_current_databases
2526
from .utils import get_basic_auth_header, spy_on
2627

2728

@@ -1319,7 +1320,7 @@ def test_pre_auth_default_scopes(self):
13191320
self.assertEqual(form["client_id"].value(), self.application.client_id)
13201321

13211322

1322-
@pytest.mark.django_db
1323+
@pytest.mark.django_db(databases=retrieve_current_databases())
13231324
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
13241325
def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_application, client, oidc_key):
13251326
client.force_login(test_user)
@@ -1368,7 +1369,7 @@ def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_app
13681369
assert claims["nonce"] == "random_nonce_string"
13691370

13701371

1371-
@pytest.mark.django_db
1372+
@pytest.mark.django_db(databases=retrieve_current_databases())
13721373
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
13731374
def test_claims_passed_to_code_generation(
13741375
oauth2_settings, test_user, hybrid_application, client, mocker, oidc_key

tests/test_models.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from . import presets
2222
from .common_testing import OAuth2ProviderTestCase as TestCase
23+
from .common_testing import retrieve_current_databases
2324

2425

2526
CLEARTEXT_SECRET = "1234567890abcdefghijklmnopqrstuvwxyz"
@@ -466,7 +467,7 @@ def test_clear_expired_tokens_with_tokens(self):
466467
assert remaining_gt_count == initial_gt_count // 2, "half the remaining grants should still exist."
467468

468469

469-
@pytest.mark.django_db
470+
@pytest.mark.django_db(databases=retrieve_current_databases())
470471
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
471472
def test_id_token_methods(oidc_tokens, rf):
472473
id_token = IDToken.objects.get()
@@ -501,7 +502,7 @@ def test_id_token_methods(oidc_tokens, rf):
501502
assert IDToken.objects.filter(jti=id_token.jti).count() == 0
502503

503504

504-
@pytest.mark.django_db
505+
@pytest.mark.django_db(databases=retrieve_current_databases())
505506
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
506507
def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf):
507508
id_token = IDToken.objects.get()
@@ -540,7 +541,7 @@ def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf):
540541
assert not IDToken.objects.filter(jti=id_token.jti).exists()
541542

542543

543-
@pytest.mark.django_db
544+
@pytest.mark.django_db(databases=retrieve_current_databases())
544545
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
545546
def test_application_key(oauth2_settings, application):
546547
# RS256 key
@@ -565,7 +566,7 @@ def test_application_key(oauth2_settings, application):
565566
assert "This application does not support signed tokens" == str(exc.value)
566567

567568

568-
@pytest.mark.django_db
569+
@pytest.mark.django_db(databases=retrieve_current_databases())
569570
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
570571
def test_application_clean(oauth2_settings, application):
571572
# RS256, RSA key is configured
@@ -605,15 +606,15 @@ def test_application_clean(oauth2_settings, application):
605606
application.clean()
606607

607608

608-
@pytest.mark.django_db
609+
@pytest.mark.django_db(databases=retrieve_current_databases())
609610
@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT)
610611
def test_application_origin_allowed_default_https(oauth2_settings, cors_application):
611612
"""Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https"""
612613
assert cors_application.origin_allowed("https://example.com")
613614
assert not cors_application.origin_allowed("http://example.com")
614615

615616

616-
@pytest.mark.django_db
617+
@pytest.mark.django_db(databases=retrieve_current_databases())
617618
@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP)
618619
def test_application_origin_allowed_http(oauth2_settings, cors_application):
619620
"""Test that http schemes are allowed because http was added to ALLOWED_SCHEMES"""

tests/test_oauth2_validators.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from . import presets
1818
from .common_testing import OAuth2ProviderTestCase as TestCase
1919
from .common_testing import OAuth2ProviderTransactionTestCase as TransactionTestCase
20+
from .common_testing import retrieve_current_databases
2021
from .utils import get_basic_auth_header
2122

2223

@@ -546,7 +547,7 @@ def test_get_jwt_bearer_token(oauth2_settings, mocker):
546547
assert mock_get_id_token.call_args[1] == {}
547548

548549

549-
@pytest.mark.django_db
550+
@pytest.mark.django_db(databases=retrieve_current_databases())
550551
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
551552
def test_validate_id_token_expired_jwt(oauth2_settings, mocker, oidc_tokens):
552553
mocker.patch("oauth2_provider.oauth2_validators.jwt.JWT", side_effect=jwt.JWTExpired)
@@ -562,7 +563,7 @@ def test_validate_id_token_no_token(oauth2_settings, mocker):
562563
assert status is False
563564

564565

565-
@pytest.mark.django_db
566+
@pytest.mark.django_db(databases=retrieve_current_databases())
566567
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
567568
def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens):
568569
oidc_tokens.application.delete()
@@ -571,7 +572,7 @@ def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens):
571572
assert status is False
572573

573574

574-
@pytest.mark.django_db
575+
@pytest.mark.django_db(databases=retrieve_current_databases())
575576
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW)
576577
def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key):
577578
token = jwt.JWT(header=json.dumps({"alg": "RS256"}), claims=json.dumps({"bad": "token"}))

tests/test_oidc_views.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from . import presets
2121
from .common_testing import OAuth2ProviderTestCase as TestCase
22+
from .common_testing import retrieve_current_databases
2223

2324

2425
@pytest.mark.usefixtures("oauth2_settings")
@@ -221,7 +222,7 @@ def mock_request_for(user):
221222
return request
222223

223224

224-
@pytest.mark.django_db
225+
@pytest.mark.django_db(databases=retrieve_current_databases())
225226
def test_validate_logout_request(oidc_tokens, public_application, rp_settings):
226227
oidc_tokens = oidc_tokens
227228
application = oidc_tokens.application
@@ -299,7 +300,7 @@ def test_validate_logout_request(oidc_tokens, public_application, rp_settings):
299300
)
300301

301302

302-
@pytest.mark.django_db
303+
@pytest.mark.django_db(databases=retrieve_current_databases())
303304
@pytest.mark.parametrize("ALWAYS_PROMPT", [True, False])
304305
def test_must_prompt(oidc_tokens, other_user, rp_settings, ALWAYS_PROMPT):
305306
rp_settings.OIDC_RP_INITIATED_LOGOUT_ALWAYS_PROMPT = ALWAYS_PROMPT
@@ -320,14 +321,14 @@ def is_logged_in(client):
320321
return get_user(client).is_authenticated
321322

322323

323-
@pytest.mark.django_db
324+
@pytest.mark.django_db(databases=retrieve_current_databases())
324325
def test_rp_initiated_logout_get(logged_in_client, rp_settings):
325326
rsp = logged_in_client.get(reverse("oauth2_provider:rp-initiated-logout"), data={})
326327
assert rsp.status_code == 200
327328
assert is_logged_in(logged_in_client)
328329

329330

330-
@pytest.mark.django_db
331+
@pytest.mark.django_db(databases=retrieve_current_databases())
331332
def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_settings):
332333
rsp = logged_in_client.get(
333334
reverse("oauth2_provider:rp-initiated-logout"), data={"id_token_hint": oidc_tokens.id_token}
@@ -337,7 +338,7 @@ def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_sett
337338
assert not is_logged_in(logged_in_client)
338339

339340

340-
@pytest.mark.django_db
341+
@pytest.mark.django_db(databases=retrieve_current_databases())
341342
def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens, rp_settings):
342343
validator = oauth2_settings.OAUTH2_VALIDATOR_CLASS()
343344
validator._load_id_token(oidc_tokens.id_token).revoke()
@@ -348,7 +349,7 @@ def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens,
348349
assert is_logged_in(logged_in_client)
349350

350351

351-
@pytest.mark.django_db
352+
@pytest.mark.django_db(databases=retrieve_current_databases())
352353
def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens, rp_settings):
353354
rsp = logged_in_client.get(
354355
reverse("oauth2_provider:rp-initiated-logout"),
@@ -359,7 +360,7 @@ def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens
359360
assert not is_logged_in(logged_in_client)
360361

361362

362-
@pytest.mark.django_db
363+
@pytest.mark.django_db(databases=retrieve_current_databases())
363364
def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client, oidc_tokens, rp_settings):
364365
rsp = logged_in_client.get(
365366
reverse("oauth2_provider:rp-initiated-logout"),
@@ -374,7 +375,7 @@ def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client,
374375
assert not is_logged_in(logged_in_client)
375376

376377

377-
@pytest.mark.django_db
378+
@pytest.mark.django_db(databases=retrieve_current_databases())
378379
def test_rp_initiated_logout_get_id_token_missmatch_client_id(
379380
logged_in_client, oidc_tokens, public_application, rp_settings
380381
):
@@ -386,7 +387,7 @@ def test_rp_initiated_logout_get_id_token_missmatch_client_id(
386387
assert is_logged_in(logged_in_client)
387388

388389

389-
@pytest.mark.django_db
390+
@pytest.mark.django_db(databases=retrieve_current_databases())
390391
def test_rp_initiated_logout_public_client_redirect_client_id(
391392
logged_in_client, oidc_non_confidential_tokens, public_application, rp_settings
392393
):
@@ -402,7 +403,7 @@ def test_rp_initiated_logout_public_client_redirect_client_id(
402403
assert not is_logged_in(logged_in_client)
403404

404405

405-
@pytest.mark.django_db
406+
@pytest.mark.django_db(databases=retrieve_current_databases())
406407
def test_rp_initiated_logout_public_client_strict_redirect_client_id(
407408
logged_in_client, oidc_non_confidential_tokens, public_application, oauth2_settings
408409
):
@@ -419,7 +420,7 @@ def test_rp_initiated_logout_public_client_strict_redirect_client_id(
419420
assert is_logged_in(logged_in_client)
420421

421422

422-
@pytest.mark.django_db
423+
@pytest.mark.django_db(databases=retrieve_current_databases())
423424
def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_settings):
424425
rsp = logged_in_client.get(
425426
reverse("oauth2_provider:rp-initiated-logout"), data={"client_id": oidc_tokens.application.client_id}
@@ -428,7 +429,7 @@ def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_set
428429
assert is_logged_in(logged_in_client)
429430

430431

431-
@pytest.mark.django_db
432+
@pytest.mark.django_db(databases=retrieve_current_databases())
432433
def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings):
433434
form_data = {
434435
"client_id": oidc_tokens.application.client_id,
@@ -438,7 +439,7 @@ def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings):
438439
assert is_logged_in(logged_in_client)
439440

440441

441-
@pytest.mark.django_db
442+
@pytest.mark.django_db(databases=retrieve_current_databases())
442443
def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_settings):
443444
form_data = {"client_id": oidc_tokens.application.client_id, "allow": True}
444445
rsp = logged_in_client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data)
@@ -447,7 +448,7 @@ def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_sett
447448
assert not is_logged_in(logged_in_client)
448449

449450

450-
@pytest.mark.django_db
451+
@pytest.mark.django_db(databases=retrieve_current_databases())
451452
def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings):
452453
form_data = {"client_id": oidc_tokens.application.client_id, "allow": True}
453454
rsp = client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data)
@@ -456,7 +457,7 @@ def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings):
456457
assert not is_logged_in(client)
457458

458459

459-
@pytest.mark.django_db
460+
@pytest.mark.django_db(databases=retrieve_current_databases())
460461
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT)
461462
def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application, expired_id_token):
462463
# Accepting expired (but otherwise valid and signed by us) tokens is enabled. Logout should go through.
@@ -471,7 +472,7 @@ def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application
471472
assert not is_logged_in(logged_in_client)
472473

473474

474-
@pytest.mark.django_db
475+
@pytest.mark.django_db(databases=retrieve_current_databases())
475476
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED)
476477
def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application, expired_id_token):
477478
# Expired tokens should not be accepted by default.
@@ -486,30 +487,30 @@ def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application,
486487
assert is_logged_in(logged_in_client)
487488

488489

489-
@pytest.mark.django_db
490+
@pytest.mark.django_db(databases=retrieve_current_databases())
490491
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT)
491492
def test_load_id_token_accept_expired(expired_id_token):
492493
id_token, _ = _load_id_token(expired_id_token)
493494
assert isinstance(id_token, get_id_token_model())
494495

495496

496-
@pytest.mark.django_db
497+
@pytest.mark.django_db(databases=retrieve_current_databases())
497498
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT)
498499
def test_load_id_token_wrong_aud(id_token_wrong_aud):
499500
id_token, claims = _load_id_token(id_token_wrong_aud)
500501
assert id_token is None
501502
assert claims is None
502503

503504

504-
@pytest.mark.django_db
505+
@pytest.mark.django_db(databases=retrieve_current_databases())
505506
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED)
506507
def test_load_id_token_deny_expired(expired_id_token):
507508
id_token, claims = _load_id_token(expired_id_token)
508509
assert id_token is None
509510
assert claims is None
510511

511512

512-
@pytest.mark.django_db
513+
@pytest.mark.django_db(databases=retrieve_current_databases())
513514
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT)
514515
def test_validate_claims_wrong_iss(id_token_wrong_iss):
515516
id_token, claims = _load_id_token(id_token_wrong_iss)
@@ -518,15 +519,15 @@ def test_validate_claims_wrong_iss(id_token_wrong_iss):
518519
assert not _validate_claims(mock_request(), claims)
519520

520521

521-
@pytest.mark.django_db
522+
@pytest.mark.django_db(databases=retrieve_current_databases())
522523
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT)
523524
def test_validate_claims(oidc_tokens):
524525
id_token, claims = _load_id_token(oidc_tokens.id_token)
525526
assert claims is not None
526527
assert _validate_claims(mock_request_for(oidc_tokens.user), claims)
527528

528529

529-
@pytest.mark.django_db
530+
@pytest.mark.django_db(databases=retrieve_current_databases())
530531
@pytest.mark.parametrize("method", ["get", "post"])
531532
def test_userinfo_endpoint(oidc_tokens, client, method):
532533
auth_header = "Bearer %s" % oidc_tokens.access_token
@@ -539,7 +540,7 @@ def test_userinfo_endpoint(oidc_tokens, client, method):
539540
assert data["sub"] == str(oidc_tokens.user.pk)
540541

541542

542-
@pytest.mark.django_db
543+
@pytest.mark.django_db(databases=retrieve_current_databases())
543544
def test_userinfo_endpoint_bad_token(oidc_tokens, client):
544545
# No access token
545546
rsp = client.get(reverse("oauth2_provider:user-info"))
@@ -552,7 +553,7 @@ def test_userinfo_endpoint_bad_token(oidc_tokens, client):
552553
assert rsp.status_code == 401
553554

554555

555-
@pytest.mark.django_db
556+
@pytest.mark.django_db(databases=retrieve_current_databases())
556557
def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings):
557558
AccessToken = get_access_token_model()
558559
IDToken = get_id_token_model()
@@ -575,7 +576,7 @@ def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings):
575576
assert all([token.revoked <= timezone.now() for token in RefreshToken.objects.all()])
576577

577578

578-
@pytest.mark.django_db
579+
@pytest.mark.django_db(databases=retrieve_current_databases())
579580
def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settings):
580581
AccessToken = get_access_token_model()
581582
IDToken = get_id_token_model()
@@ -616,7 +617,7 @@ def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settin
616617
assert all(token.revoked <= timezone.now() for token in RefreshToken.objects.all())
617618

618619

619-
@pytest.mark.django_db
620+
@pytest.mark.django_db(databases=retrieve_current_databases())
620621
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_KEEP_TOKENS)
621622
def test_token_deletion_on_logout_disabled(oidc_tokens, logged_in_client, rp_settings):
622623
rp_settings.OIDC_RP_INITIATED_LOGOUT_DELETE_TOKENS = False
@@ -652,7 +653,7 @@ def claim_user_email(request):
652653
return EXAMPLE_EMAIL
653654

654655

655-
@pytest.mark.django_db
656+
@pytest.mark.django_db(databases=retrieve_current_databases())
656657
def test_userinfo_endpoint_custom_claims_callable(oidc_tokens, client, oauth2_settings):
657658
class CustomValidator(OAuth2Validator):
658659
oidc_claim_scope = None
@@ -680,7 +681,7 @@ def get_additional_claims(self):
680681
assert data["email"] == EXAMPLE_EMAIL
681682

682683

683-
@pytest.mark.django_db
684+
@pytest.mark.django_db(databases=retrieve_current_databases())
684685
def test_userinfo_endpoint_custom_claims_email_scope_callable(
685686
oidc_email_scope_tokens, client, oauth2_settings
686687
):
@@ -707,7 +708,7 @@ def get_additional_claims(self):
707708
assert data["email"] == EXAMPLE_EMAIL
708709

709710

710-
@pytest.mark.django_db
711+
@pytest.mark.django_db(databases=retrieve_current_databases())
711712
def test_userinfo_endpoint_custom_claims_plain(oidc_tokens, client, oauth2_settings):
712713
class CustomValidator(OAuth2Validator):
713714
oidc_claim_scope = None
@@ -735,7 +736,7 @@ def get_additional_claims(self, request):
735736
assert data["email"] == EXAMPLE_EMAIL
736737

737738

738-
@pytest.mark.django_db
739+
@pytest.mark.django_db(databases=retrieve_current_databases())
739740
def test_userinfo_endpoint_custom_claims_email_scopeplain(oidc_email_scope_tokens, client, oauth2_settings):
740741
class CustomValidator(OAuth2Validator):
741742
def get_additional_claims(self, request):

0 commit comments

Comments
 (0)