Skip to content

Commit 2de4f22

Browse files
committed
fix(strategy): move type handling for storage to property
Avoid putting the burden further down, rather wrap this (unusual) case directly in the class.
1 parent 377b9e4 commit 2de4f22

File tree

11 files changed

+51
-67
lines changed

11 files changed

+51
-67
lines changed

social_core/actions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from collections.abc import Callable
1616

1717
from .backends.base import BaseAuth
18-
from .storage import BaseStorage, UserProtocol
18+
from .storage import UserProtocol
1919
from .strategy import HttpResponseProtocol
2020

2121

@@ -80,7 +80,7 @@ def do_complete( # noqa: C901,PLR0912
8080

8181
# check if the output value is something else than a user and just
8282
# return it to the client
83-
user_model = cast("type[BaseStorage]", backend.strategy.storage).user.user_model()
83+
user_model = backend.strategy.storage.user.user_model()
8484
if authenticated_user and not isinstance(authenticated_user, user_model):
8585
return cast("HttpResponseProtocol", authenticated_user)
8686

social_core/backends/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from requests import Response
1717
from requests.auth import AuthBase
1818

19-
from social_core.storage import BaseStorage, PartialMixin, UserProtocol
19+
from social_core.storage import PartialMixin, UserProtocol
2020
from social_core.strategy import BaseStrategy, HttpResponseProtocol
2121

2222

@@ -122,7 +122,7 @@ def pipeline(
122122
def disconnect(self, *args, **kwargs) -> dict:
123123
pipeline = self.strategy.get_disconnect_pipeline(self)
124124
kwargs["name"] = self.name
125-
kwargs["user_storage"] = cast("type[BaseStorage]", self.strategy.storage).user
125+
kwargs["user_storage"] = self.strategy.storage.user
126126
return self.run_pipeline(pipeline, *args, **kwargs)
127127

128128
def run_pipeline(

social_core/backends/discourse.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,13 @@
22
import time
33
from base64 import urlsafe_b64decode, urlsafe_b64encode
44
from hashlib import sha256
5-
from typing import TYPE_CHECKING, cast
65
from urllib.parse import urlencode
76

87
from social_core.exceptions import AuthException, AuthTokenError
98
from social_core.utils import parse_qs
109

1110
from .base import BaseAuth
1211

13-
if TYPE_CHECKING:
14-
from social_core.storage import BaseStorage
15-
1612

1713
class DiscourseAuth(BaseAuth):
1814
name = "discourse"
@@ -55,17 +51,13 @@ def get_user_details(self, response):
5551
}
5652

5753
def add_nonce(self, nonce) -> None:
58-
cast("type[BaseStorage]", self.strategy.storage).nonce.use(
59-
self.setting("SERVER_URL"), time.time(), nonce
60-
)
54+
self.strategy.storage.nonce.use(self.setting("SERVER_URL"), time.time(), nonce)
6155

6256
def get_nonce(self, nonce):
63-
return cast("type[BaseStorage]", self.strategy.storage).nonce.get(
64-
self.setting("SERVER_URL"), nonce
65-
)
57+
return self.strategy.storage.nonce.get(self.setting("SERVER_URL"), nonce)
6658

6759
def delete_nonce(self, nonce) -> None:
68-
cast("type[BaseStorage]", self.strategy.storage).nonce.delete(nonce)
60+
self.strategy.storage.nonce.delete(nonce)
6961

7062
def auth_complete(self, *args, **kwargs):
7163
"""

social_core/backends/open_id_connect.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
from requests.auth import AuthBase
3232

33-
from social_core.storage import BaseStorage
3433
from social_core.strategy import BaseStrategy
3534

3635

@@ -235,21 +234,19 @@ def get_and_store_nonce(self, url, state):
235234
nonce = self.strategy.random_string(64)
236235
# Store the nonce
237236
association = OpenIdConnectAssociation(nonce, assoc_type=state)
238-
cast("type[BaseStorage]", self.strategy.storage).association.store(
239-
url, association
240-
)
237+
self.strategy.storage.association.store(url, association)
241238
return nonce
242239

243240
def get_nonce(self, nonce):
244241
try:
245-
return cast("type[BaseStorage]", self.strategy.storage).association.get(
242+
return self.strategy.storage.association.get(
246243
server_url=self.authorization_url(), handle=nonce
247244
)[0]
248245
except IndexError:
249246
return None
250247

251248
def remove_nonce(self, nonce_id) -> None:
252-
cast("type[BaseStorage]", self.strategy.storage).association.remove([nonce_id])
249+
self.strategy.storage.association.remove([nonce_id])
253250

254251
def validate_claims(self, id_token) -> None:
255252
utc_timestamp = timegm(datetime.datetime.now(datetime.timezone.utc).timetuple())

social_core/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ class SocialAuthBaseException(ValueError):
1010
"""Base class for pipeline exceptions."""
1111

1212

13+
class SocialAuthImproperlyConfiguredError(SocialAuthBaseException):
14+
"""Raised when configuration is invalid."""
15+
16+
1317
class StrategyMissingFeatureError(SocialAuthBaseException):
1418
"""Strategy does not support this."""
1519

social_core/pipeline/social_auth.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, cast
3+
from typing import TYPE_CHECKING
44

55
from social_core.exceptions import AuthAlreadyAssociated, AuthException, AuthForbidden
66

77
if TYPE_CHECKING:
88
from social_core.backends.base import BaseAuth
9-
from social_core.storage import BaseStorage, UserProtocol
9+
from social_core.storage import UserProtocol
1010

1111

1212
def social_details(backend: BaseAuth, details, response, *args, **kwargs):
@@ -26,9 +26,7 @@ def social_user(
2626
backend: BaseAuth, uid, user: UserProtocol | None = None, *args, **kwargs
2727
):
2828
provider = backend.name
29-
social = cast("type[BaseStorage]", backend.strategy.storage).user.get_social_auth(
30-
provider, uid
31-
)
29+
social = backend.strategy.storage.user.get_social_auth(provider, uid)
3230
if social:
3331
if user and social.user != user:
3432
raise AuthAlreadyAssociated(backend)
@@ -52,14 +50,12 @@ def associate_user(
5250
):
5351
if user and not social:
5452
try:
55-
social = cast(
56-
"type[BaseStorage]", backend.strategy.storage
57-
).user.create_social_auth(user, uid, backend.name)
53+
social = backend.strategy.storage.user.create_social_auth(
54+
user, uid, backend.name
55+
)
5856
# pylint: disable-next=broad-exception-caught
5957
except Exception as err:
60-
if not cast(
61-
"type[BaseStorage]", backend.strategy.storage
62-
).is_integrity_error(err):
58+
if not backend.strategy.storage.is_integrity_error(err):
6359
raise
6460
# Protect for possible race condition, those bastard with FTL
6561
# clicking capabilities, check issue #131:
@@ -95,11 +91,7 @@ def associate_by_email(
9591
# Try to associate accounts registered with the same email address,
9692
# only if it's a single object. AuthException is raised if multiple
9793
# objects are returned.
98-
users = list(
99-
cast("type[BaseStorage]", backend.strategy.storage).user.get_users_by_email(
100-
email
101-
)
102-
)
94+
users = list(backend.strategy.storage.user.get_users_by_email(email))
10395
if len(users) == 0:
10496
return None
10597
if len(users) > 1:
@@ -119,9 +111,9 @@ def load_extra_data(
119111
*args,
120112
**kwargs,
121113
) -> None:
122-
social = kwargs.get("social") or cast(
123-
"type[BaseStorage]", backend.strategy.storage
124-
).user.get_social_auth(backend.name, uid)
114+
social = kwargs.get("social") or backend.strategy.storage.user.get_social_auth(
115+
backend.name, uid
116+
)
125117
if social:
126118
extra_data = backend.extra_data(user, uid, response, details, kwargs)
127119
social.set_extra_data(extra_data)

social_core/pipeline/user.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
if TYPE_CHECKING:
1212
from social_core.backends.base import BaseAuth
13-
from social_core.storage import BaseStorage, UserProtocol
13+
from social_core.storage import UserProtocol
1414
from social_core.strategy import BaseStrategy
1515

1616
USER_FIELDS = ["username", "email"]
@@ -173,4 +173,4 @@ def user_details(
173173
setattr(user, name, value)
174174

175175
if changed:
176-
cast("type[BaseStorage]", strategy.storage).user.changed(user)
176+
strategy.storage.user.changed(user)

social_core/pipeline/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,7 @@ def partial_load(strategy: BaseStrategy, token: str) -> PartialMixin | None:
8989
).user.get_social_auth(**social) # type: ignore[missing-argument]
9090

9191
if user:
92-
kwargs["user"] = cast("type[BaseStorage]", strategy.storage).user.get_user(
93-
user
94-
)
92+
kwargs["user"] = strategy.storage.user.get_user(user)
9593

9694
partial.args = [strategy.from_session_value(val) for val in args]
9795
partial.kwargs = {

social_core/strategy.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from typing import TYPE_CHECKING, Any, Protocol, cast
55

66
from .backends.utils import get_backend
7-
from .exceptions import StrategyMissingBackendError, StrategyMissingFeatureError
7+
from .exceptions import (
8+
SocialAuthImproperlyConfiguredError,
9+
StrategyMissingBackendError,
10+
StrategyMissingFeatureError,
11+
)
812
from .pipeline import DEFAULT_AUTH_PIPELINE, DEFAULT_DISCONNECT_PIPELINE
913
from .pipeline.utils import partial_load
1014
from .store import OpenIdSessionWrapper, OpenIdStore
@@ -53,9 +57,15 @@ def __init__(
5357
storage: type[BaseStorage] | None = None,
5458
tpl: type[BaseTemplateStrategy] | None = None,
5559
) -> None:
56-
self.storage = storage
60+
self._storage = storage
5761
self.tpl = (tpl or self.DEFAULT_TEMPLATE_STRATEGY)(self)
5862

63+
@property
64+
def storage(self) -> type[BaseStorage]:
65+
if self._storage is None:
66+
raise StrategyMissingBackendError
67+
return self._storage
68+
5969
def setting(self, name: str, default=None, backend: BaseAuth | None = None):
6070
names = [setting_name(name), name]
6171
if backend:
@@ -68,13 +78,9 @@ def setting(self, name: str, default=None, backend: BaseAuth | None = None):
6878
return default
6979

7080
def create_user(self, *args, **kwargs):
71-
if self.storage is None:
72-
raise StrategyMissingBackendError
7381
return self.storage.user.create_user(*args, **kwargs)
7482

7583
def get_user(self, *args, **kwargs):
76-
if self.storage is None:
77-
raise StrategyMissingBackendError
7884
return self.storage.user.get_user(*args, **kwargs)
7985

8086
def session_setdefault(self, name: str, value):
@@ -121,8 +127,6 @@ def partial_load(self, token: str) -> PartialMixin | None:
121127
return partial_load(self, token)
122128

123129
def clean_partial_pipeline(self, token) -> None:
124-
if self.storage is None:
125-
raise StrategyMissingBackendError
126130
self.storage.partial.destroy(token)
127131
current_token_in_session = self.session_get(PARTIAL_TOKEN_SESSION_NAME)
128132
if current_token_in_session == token:
@@ -158,17 +162,17 @@ def get_language(self) -> str:
158162
def send_email_validation(
159163
self, backend: BaseAuth, email: str, partial_token: str | None = None
160164
) -> CodeMixin:
161-
if self.storage is None:
162-
raise StrategyMissingBackendError
163165
email_validation = self.setting("EMAIL_VALIDATION_FUNCTION")
166+
if not email_validation:
167+
raise SocialAuthImproperlyConfiguredError(
168+
"EMAIL_VALIDATION_FUNCTION missing"
169+
)
164170
send_email = module_member(email_validation)
165171
code = self.storage.code.make_code(email)
166172
send_email(self, backend, code, partial_token)
167173
return code
168174

169175
def validate_email(self, email: str, code: str) -> bool:
170-
if self.storage is None:
171-
raise StrategyMissingBackendError
172176
verification_code = self.storage.code.get_code(code)
173177
if not verification_code or verification_code.code != code:
174178
return False
@@ -193,8 +197,6 @@ def authenticate(
193197
) -> UserProtocol | HttpResponseProtocol | None:
194198
"""Trigger the authentication mechanism tied to the current
195199
framework"""
196-
if self.storage is None:
197-
raise StrategyMissingBackendError
198200
kwargs["strategy"] = self
199201
kwargs["storage"] = self.storage
200202
kwargs["backend"] = backend

social_core/tests/strategy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
if TYPE_CHECKING:
88
from social_core.backends.base import BaseAuth
9-
from social_core.storage import BaseStorage
109

1110
TEST_URI = "http://myapp.com"
1211
TEST_HOST = "myapp.com"
@@ -30,7 +29,6 @@ def render_string(self, html, context):
3029

3130
class TestStrategy(BaseStrategy):
3231
__test__ = False
33-
storage: type[BaseStorage]
3432

3533
DEFAULT_TEMPLATE_STRATEGY = TestTemplateStrategy
3634

0 commit comments

Comments
 (0)