Skip to content

Commit 4f4a51d

Browse files
committed
feat: provide way to configure default strategy
This is to provide app modules way to provide default strategy for the backends. Django app currently handles this by monkey-patching the constructor in baseauth_init_workaround. Fixes #1490
1 parent 9c93ae2 commit 4f4a51d

File tree

5 files changed

+67
-6
lines changed

5 files changed

+67
-6
lines changed

social_core/backends/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import requests
88

99
from social_core.exceptions import AuthConnectionError, AuthUnknownError
10+
from social_core.registry import REGISTRY
11+
from social_core.storage import UserProtocol
1012
from social_core.utils import module_member, parse_qs, social_logger, user_agent
1113

1214
if TYPE_CHECKING:
@@ -31,8 +33,11 @@ class BaseAuth:
3133
REQUIRES_EMAIL_VALIDATION = False
3234
SEND_USER_AGENT = True
3335

34-
def __init__(self, strategy, redirect_uri=None) -> None:
35-
self.strategy = strategy
36+
def __init__(self, strategy=None, redirect_uri: str | None = None) -> None:
37+
# TODO: temporary type override
38+
self.strategy: Any = (
39+
strategy if strategy is not None else REGISTRY.default_strategy
40+
)
3641
self.redirect_uri = redirect_uri
3742
self.data = self.strategy.request_data()
3843
self.redirect_uri = self.strategy.absolute_uri(self.redirect_uri)

social_core/backends/open_id_connect.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import TYPE_CHECKING, Any, Literal
88

99
import jwt
10+
from cryptography.hazmat.backends import default_backend
11+
from cryptography.hazmat.primitives import hashes
1012
from jwt import (
1113
ExpiredSignatureError,
1214
InvalidAudienceError,
@@ -28,9 +30,6 @@
2830

2931
from requests.auth import AuthBase
3032

31-
from cryptography.hazmat.backends import default_backend
32-
from cryptography.hazmat.primitives import hashes
33-
3433

3534
class OpenIdConnectAssociation:
3635
"""Use Association model to save the nonce by force."""
@@ -88,7 +87,7 @@ class OpenIdConnectAuth(BaseOAuth2):
8887
LOGIN_HINT = None
8988
ACR_VALUES = None
9089

91-
def __init__(self, strategy, redirect_uri=None) -> None:
90+
def __init__(self, strategy=None, redirect_uri: str | None = None) -> None:
9291
super().__init__(strategy, redirect_uri=redirect_uri)
9392
self.id_token = None
9493

social_core/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ def __str__(self) -> str:
2222
return f"Strategy {self.strategy_name} does not support {self.feature_name}"
2323

2424

25+
class DefaultStrategyMissingError(SocialAuthBaseException):
26+
"""Default strategy is not configured."""
27+
28+
def __str__(self) -> str:
29+
return "Default strategy is not configured"
30+
31+
2532
class StrategyMissingBackendError(SocialAuthBaseException):
2633
"""Strategy storage backend is not configured."""
2734

social_core/registry.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from .exceptions import DefaultStrategyMissingError
6+
7+
if TYPE_CHECKING:
8+
from .strategy import BaseStrategy
9+
10+
11+
class Registry:
12+
def __init__(self) -> None:
13+
self._default_strategy: BaseStrategy | None = None
14+
15+
def reset(self) -> None:
16+
self._default_strategy = None
17+
18+
@property
19+
def default_strategy(self) -> BaseStrategy:
20+
if self._default_strategy is None:
21+
raise DefaultStrategyMissingError
22+
return self._default_strategy
23+
24+
@default_strategy.setter
25+
def default_strategy(self, strategy: BaseStrategy) -> None:
26+
self._default_strategy = strategy
27+
28+
29+
REGISTRY = Registry()

social_core/tests/test_registry.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import unittest
2+
3+
from social_core.exceptions import (
4+
DefaultStrategyMissingError,
5+
)
6+
from social_core.registry import REGISTRY
7+
8+
from .strategy import TestStrategy
9+
10+
11+
class StrategyRegistryTestCase(unittest.TestCase):
12+
def test_missing(self):
13+
with self.assertRaises(DefaultStrategyMissingError):
14+
self.assertIsNotNone(REGISTRY.default_strategy)
15+
16+
def test_set(self):
17+
REGISTRY.default_strategy = TestStrategy(None)
18+
try:
19+
self.assertIsInstance(REGISTRY.default_strategy, TestStrategy)
20+
finally:
21+
REGISTRY.reset()

0 commit comments

Comments
 (0)