Skip to content

Commit adf7486

Browse files
committed
feat: provide way to configure default strategy
This is to provide app modules way to provide default strategy for the backends. Django app currently handles this by monkey-patching the constructor in baseauth_init_workaround. This triggered several type annotation fixes as well. Fixes #1490
1 parent 9c93ae2 commit adf7486

File tree

12 files changed

+117
-38
lines changed

12 files changed

+117
-38
lines changed

social_core/actions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
if TYPE_CHECKING:
1515
from .backends.base import BaseAuth
16-
from .storage import UserProtocol
16+
from .storage import BaseStorage, UserProtocol
1717
from .strategy import HttpResponseProtocol
1818

1919

@@ -72,7 +72,7 @@ def do_complete(
7272

7373
# check if the output value is something else than a user and just
7474
# return it to the client
75-
user_model = backend.strategy.storage.user.user_model()
75+
user_model = cast("type[BaseStorage]", backend.strategy.storage).user.user_model()
7676
if user and not isinstance(user, user_model):
7777
return cast("HttpResponseProtocol", user)
7878

@@ -160,7 +160,7 @@ def do_disconnect(
160160
)
161161

162162
if isinstance(response, dict):
163-
url = backend.strategy.absolute_uri(
163+
url: str = backend.strategy.absolute_uri(
164164
backend.strategy.request_data().get(redirect_name, "")
165165
or backend.setting("DISCONNECT_REDIRECT_URL")
166166
or backend.setting("LOGIN_REDIRECT_URL")

social_core/backends/base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import requests
88

99
from social_core.exceptions import AuthConnectionError, AuthUnknownError
10+
from social_core.registry import REGISTRY
1011
from social_core.utils import module_member, parse_qs, social_logger, user_agent
1112

1213
if TYPE_CHECKING:
@@ -15,8 +16,8 @@
1516
from requests import Response
1617
from requests.auth import AuthBase
1718

18-
from social_core.storage import UserProtocol
19-
from social_core.strategy import HttpResponseProtocol
19+
from social_core.storage import BaseStorage, UserProtocol
20+
from social_core.strategy import BaseStrategy, HttpResponseProtocol
2021

2122

2223
class BaseAuth:
@@ -31,8 +32,10 @@ class BaseAuth:
3132
REQUIRES_EMAIL_VALIDATION = False
3233
SEND_USER_AGENT = True
3334

34-
def __init__(self, strategy, redirect_uri=None) -> None:
35-
self.strategy = strategy
35+
def __init__(
36+
self, strategy: BaseStrategy | None = None, redirect_uri: str | None = None
37+
) -> None:
38+
self.strategy = strategy if strategy is not None else REGISTRY.default_strategy
3639
self.redirect_uri = redirect_uri
3740
self.data = self.strategy.request_data()
3841
self.redirect_uri = self.strategy.absolute_uri(self.redirect_uri)
@@ -115,7 +118,7 @@ def pipeline(
115118
def disconnect(self, *args, **kwargs) -> dict:
116119
pipeline = self.strategy.get_disconnect_pipeline(self)
117120
kwargs["name"] = self.name
118-
kwargs["user_storage"] = self.strategy.storage.user
121+
kwargs["user_storage"] = cast("type[BaseStorage]", self.strategy.storage).user
119122
return self.run_pipeline(pipeline, *args, **kwargs)
120123

121124
def run_pipeline(

social_core/backends/discourse.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22
import time
33
from base64 import urlsafe_b64decode, urlsafe_b64encode
44
from hashlib import sha256
5+
from typing import TYPE_CHECKING, cast
56
from urllib.parse import urlencode
67

78
from social_core.exceptions import AuthException, AuthTokenError
89
from social_core.utils import parse_qs
910

1011
from .base import BaseAuth
1112

13+
if TYPE_CHECKING:
14+
from social_core.storage import BaseStorage
15+
1216

1317
class DiscourseAuth(BaseAuth):
1418
name = "discourse"
@@ -51,13 +55,17 @@ def get_user_details(self, response):
5155
}
5256

5357
def add_nonce(self, nonce) -> None:
54-
self.strategy.storage.nonce.use(self.setting("SERVER_URL"), time.time(), nonce)
58+
cast("type[BaseStorage]", self.strategy.storage).nonce.use(
59+
self.setting("SERVER_URL"), time.time(), nonce
60+
)
5561

5662
def get_nonce(self, nonce):
57-
return self.strategy.storage.nonce.get(self.setting("SERVER_URL"), nonce)
63+
return cast("type[BaseStorage]", self.strategy.storage).nonce.get(
64+
self.setting("SERVER_URL"), nonce
65+
)
5866

5967
def delete_nonce(self, nonce) -> None:
60-
self.strategy.storage.nonce.delete(nonce)
68+
cast("type[BaseStorage]", self.strategy.storage).nonce.delete(nonce)
6169

6270
def auth_complete(self, *args, **kwargs):
6371
"""

social_core/backends/open_id_connect.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import datetime
55
import json
66
from calendar import timegm
7-
from typing import TYPE_CHECKING, Any, Literal
7+
from typing import TYPE_CHECKING, Any, Literal, cast
88

99
import jwt
10+
from cryptography.hazmat.backends import default_backend
11+
from cryptography.hazmat.primitives import hashes
1012
from jwt import (
1113
ExpiredSignatureError,
1214
InvalidAudienceError,
@@ -28,8 +30,8 @@
2830

2931
from requests.auth import AuthBase
3032

31-
from cryptography.hazmat.backends import default_backend
32-
from cryptography.hazmat.primitives import hashes
33+
from social_core.storage import BaseStorage
34+
from social_core.strategy import BaseStrategy
3335

3436

3537
class OpenIdConnectAssociation:
@@ -88,7 +90,9 @@ class OpenIdConnectAuth(BaseOAuth2):
8890
LOGIN_HINT = None
8991
ACR_VALUES = None
9092

91-
def __init__(self, strategy, redirect_uri=None) -> None:
93+
def __init__(
94+
self, strategy: BaseStrategy | None = None, redirect_uri: str | None = None
95+
) -> None:
9296
super().__init__(strategy, redirect_uri=redirect_uri)
9397
self.id_token = None
9498

@@ -231,19 +235,21 @@ def get_and_store_nonce(self, url, state):
231235
nonce = self.strategy.random_string(64)
232236
# Store the nonce
233237
association = OpenIdConnectAssociation(nonce, assoc_type=state)
234-
self.strategy.storage.association.store(url, association)
238+
cast("type[BaseStorage]", self.strategy.storage).association.store(
239+
url, association
240+
)
235241
return nonce
236242

237243
def get_nonce(self, nonce):
238244
try:
239-
return self.strategy.storage.association.get(
245+
return cast("type[BaseStorage]", self.strategy.storage).association.get(
240246
server_url=self.authorization_url(), handle=nonce
241247
)[0]
242248
except IndexError:
243249
pass
244250

245251
def remove_nonce(self, nonce_id) -> None:
246-
self.strategy.storage.association.remove([nonce_id])
252+
cast("type[BaseStorage]", self.strategy.storage).association.remove([nonce_id])
247253

248254
def validate_claims(self, id_token) -> None:
249255
utc_timestamp = timegm(datetime.datetime.now(datetime.timezone.utc).timetuple())

social_core/backends/vk.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ def auth_complete(self, *args, **kwargs):
207207
response = {self.id_key(): user_id}
208208
response.update(json.loads(request["api_result"])["response"][0])
209209
return self.strategy.authenticate(
210-
*args,
211210
auth=self,
212211
backend=self,
213212
request=request,

social_core/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ def __str__(self) -> str:
2222
return f"Strategy {self.strategy_name} does not support {self.feature_name}"
2323

2424

25+
class DefaultStrategyMissingError(SocialAuthBaseException):
26+
"""Default strategy is not configured."""
27+
28+
def __str__(self) -> str:
29+
return "Default strategy is not configured"
30+
31+
2532
class StrategyMissingBackendError(SocialAuthBaseException):
2633
"""Strategy storage backend is not configured."""
2734

social_core/pipeline/social_auth.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, cast
44

55
from social_core.exceptions import AuthAlreadyAssociated, AuthException, AuthForbidden
66

77
if TYPE_CHECKING:
88
from social_core.backends.base import BaseAuth
9-
from social_core.storage import UserProtocol
9+
from social_core.storage import BaseStorage, UserProtocol
1010

1111

1212
def social_details(backend: BaseAuth, details, response, *args, **kwargs):
@@ -26,7 +26,9 @@ def social_user(
2626
backend: BaseAuth, uid, user: UserProtocol | None = None, *args, **kwargs
2727
):
2828
provider = backend.name
29-
social = backend.strategy.storage.user.get_social_auth(provider, uid)
29+
social = cast("type[BaseStorage]", backend.strategy.storage).user.get_social_auth(
30+
provider, uid
31+
)
3032
if social:
3133
if user and social.user != user:
3234
raise AuthAlreadyAssociated(backend)
@@ -50,11 +52,13 @@ def associate_user(
5052
):
5153
if user and not social:
5254
try:
53-
social = backend.strategy.storage.user.create_social_auth(
54-
user, uid, backend.name
55-
)
55+
social = cast(
56+
"type[BaseStorage]", backend.strategy.storage
57+
).user.create_social_auth(user, uid, backend.name)
5658
except Exception as err:
57-
if not backend.strategy.storage.is_integrity_error(err):
59+
if not cast(
60+
"type[BaseStorage]", backend.strategy.storage
61+
).is_integrity_error(err):
5862
raise
5963
# Protect for possible race condition, those bastard with FTL
6064
# clicking capabilities, check issue #131:
@@ -91,7 +95,11 @@ def associate_by_email(
9195
# Try to associate accounts registered with the same email address,
9296
# only if it's a single object. AuthException is raised if multiple
9397
# objects are returned.
94-
users = list(backend.strategy.storage.user.get_users_by_email(email))
98+
users = list(
99+
cast("type[BaseStorage]", backend.strategy.storage).user.get_users_by_email(
100+
email
101+
)
102+
)
95103
if len(users) == 0:
96104
return None
97105
if len(users) > 1:
@@ -111,9 +119,9 @@ def load_extra_data(
111119
*args,
112120
**kwargs,
113121
) -> None:
114-
social = kwargs.get("social") or backend.strategy.storage.user.get_social_auth(
115-
backend.name, uid
116-
)
122+
social = kwargs.get("social") or cast(
123+
"type[BaseStorage]", backend.strategy.storage
124+
).user.get_social_auth(backend.name, uid)
117125
if social:
118126
extra_data = backend.extra_data(user, uid, response, details, kwargs)
119127
social.set_extra_data(extra_data)

social_core/pipeline/user.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
if TYPE_CHECKING:
1212
from social_core.backends.base import BaseAuth
13-
from social_core.storage import UserProtocol
13+
from social_core.storage import BaseStorage, UserProtocol
1414
from social_core.strategy import BaseStrategy
1515

1616
USER_FIELDS = ["username", "email"]
@@ -173,4 +173,4 @@ def user_details(
173173
setattr(user, name, value)
174174

175175
if changed:
176-
strategy.storage.user.changed(user)
176+
cast("type[BaseStorage]", strategy.storage).user.changed(user)

social_core/pipeline/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, cast
44

55
from social_core.exceptions import (
66
StrategyMissingBackendError,
77
)
88

99
if TYPE_CHECKING:
1010
from social_core.backends.base import BaseAuth
11-
from social_core.storage import PartialMixin, UserProtocol
11+
from social_core.storage import BaseStorage, PartialMixin, UserProtocol
1212
from social_core.strategy import BaseStrategy
1313

1414
SERIALIZABLE_TYPES = (dict, list, tuple, set, bool, type(None), int, str, bytes)
@@ -84,10 +84,14 @@ def partial_load(strategy: BaseStrategy, token: str) -> PartialMixin | None:
8484
social = kwargs.get("social")
8585

8686
if isinstance(social, dict):
87-
kwargs["social"] = strategy.storage.user.get_social_auth(**social) # type: ignore[missing-argument]
87+
kwargs["social"] = cast(
88+
"type[BaseStorage]", strategy.storage
89+
).user.get_social_auth(**social) # type: ignore[missing-argument]
8890

8991
if user:
90-
kwargs["user"] = strategy.storage.user.get_user(user)
92+
kwargs["user"] = cast("type[BaseStorage]", strategy.storage).user.get_user(
93+
user
94+
)
9195

9296
partial.args = [strategy.from_session_value(val) for val in args]
9397
partial.kwargs = {

social_core/registry.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from .exceptions import DefaultStrategyMissingError
6+
7+
if TYPE_CHECKING:
8+
from .strategy import BaseStrategy
9+
10+
11+
class Registry:
12+
def __init__(self) -> None:
13+
self._default_strategy: BaseStrategy | None = None
14+
15+
@property
16+
def default_strategy(self) -> BaseStrategy:
17+
if self._default_strategy is None:
18+
raise DefaultStrategyMissingError
19+
return self._default_strategy
20+
21+
@default_strategy.setter
22+
def default_strategy(self, strategy: BaseStrategy) -> None:
23+
self._default_strategy = strategy
24+
25+
26+
REGISTRY = Registry()

0 commit comments

Comments
 (0)