diff --git a/social_core/actions.py b/social_core/actions.py index 3f95bed0..3127e237 100644 --- a/social_core/actions.py +++ b/social_core/actions.py @@ -15,7 +15,7 @@ from collections.abc import Callable from .backends.base import BaseAuth - from .storage import BaseStorage, UserProtocol + from .storage import UserProtocol from .strategy import HttpResponseProtocol @@ -80,7 +80,7 @@ def do_complete( # noqa: C901,PLR0912 # check if the output value is something else than a user and just # return it to the client - user_model = cast("type[BaseStorage]", backend.strategy.storage).user.user_model() + user_model = backend.strategy.storage.user.user_model() if authenticated_user and not isinstance(authenticated_user, user_model): return cast("HttpResponseProtocol", authenticated_user) diff --git a/social_core/backends/base.py b/social_core/backends/base.py index fb151821..398d9723 100644 --- a/social_core/backends/base.py +++ b/social_core/backends/base.py @@ -16,7 +16,7 @@ from requests import Response from requests.auth import AuthBase - from social_core.storage import BaseStorage, PartialMixin, UserProtocol + from social_core.storage import PartialMixin, UserProtocol from social_core.strategy import BaseStrategy, HttpResponseProtocol @@ -122,7 +122,7 @@ def pipeline( def disconnect(self, *args, **kwargs) -> dict: pipeline = self.strategy.get_disconnect_pipeline(self) kwargs["name"] = self.name - kwargs["user_storage"] = cast("type[BaseStorage]", self.strategy.storage).user + kwargs["user_storage"] = self.strategy.storage.user return self.run_pipeline(pipeline, *args, **kwargs) def run_pipeline( diff --git a/social_core/backends/discourse.py b/social_core/backends/discourse.py index 46082be4..2801eaca 100644 --- a/social_core/backends/discourse.py +++ b/social_core/backends/discourse.py @@ -2,7 +2,6 @@ import time from base64 import urlsafe_b64decode, urlsafe_b64encode from hashlib import sha256 -from typing import TYPE_CHECKING, cast from urllib.parse import urlencode from social_core.exceptions import AuthException, AuthTokenError @@ -10,9 +9,6 @@ from .base import BaseAuth -if TYPE_CHECKING: - from social_core.storage import BaseStorage - class DiscourseAuth(BaseAuth): name = "discourse" @@ -55,17 +51,13 @@ def get_user_details(self, response): } def add_nonce(self, nonce) -> None: - cast("type[BaseStorage]", self.strategy.storage).nonce.use( - self.setting("SERVER_URL"), time.time(), nonce - ) + self.strategy.storage.nonce.use(self.setting("SERVER_URL"), time.time(), nonce) def get_nonce(self, nonce): - return cast("type[BaseStorage]", self.strategy.storage).nonce.get( - self.setting("SERVER_URL"), nonce - ) + return self.strategy.storage.nonce.get(self.setting("SERVER_URL"), nonce) def delete_nonce(self, nonce) -> None: - cast("type[BaseStorage]", self.strategy.storage).nonce.delete(nonce) + self.strategy.storage.nonce.delete(nonce) def auth_complete(self, *args, **kwargs): """ diff --git a/social_core/backends/open_id_connect.py b/social_core/backends/open_id_connect.py index c84b6723..ce982a1d 100644 --- a/social_core/backends/open_id_connect.py +++ b/social_core/backends/open_id_connect.py @@ -30,7 +30,6 @@ from requests.auth import AuthBase - from social_core.storage import BaseStorage from social_core.strategy import BaseStrategy @@ -235,21 +234,19 @@ def get_and_store_nonce(self, url, state): nonce = self.strategy.random_string(64) # Store the nonce association = OpenIdConnectAssociation(nonce, assoc_type=state) - cast("type[BaseStorage]", self.strategy.storage).association.store( - url, association - ) + self.strategy.storage.association.store(url, association) return nonce def get_nonce(self, nonce): try: - return cast("type[BaseStorage]", self.strategy.storage).association.get( + return self.strategy.storage.association.get( server_url=self.authorization_url(), handle=nonce )[0] except IndexError: return None def remove_nonce(self, nonce_id) -> None: - cast("type[BaseStorage]", self.strategy.storage).association.remove([nonce_id]) + self.strategy.storage.association.remove([nonce_id]) def validate_claims(self, id_token) -> None: utc_timestamp = timegm(datetime.datetime.now(datetime.timezone.utc).timetuple()) diff --git a/social_core/exceptions.py b/social_core/exceptions.py index 24341a73..91e2666f 100644 --- a/social_core/exceptions.py +++ b/social_core/exceptions.py @@ -10,6 +10,10 @@ class SocialAuthBaseException(ValueError): """Base class for pipeline exceptions.""" +class SocialAuthImproperlyConfiguredError(SocialAuthBaseException): + """Raised when configuration is invalid.""" + + class StrategyMissingFeatureError(SocialAuthBaseException): """Strategy does not support this.""" diff --git a/social_core/pipeline/social_auth.py b/social_core/pipeline/social_auth.py index 639e2f55..c827d286 100644 --- a/social_core/pipeline/social_auth.py +++ b/social_core/pipeline/social_auth.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING from social_core.exceptions import AuthAlreadyAssociated, AuthException, AuthForbidden if TYPE_CHECKING: from social_core.backends.base import BaseAuth - from social_core.storage import BaseStorage, UserProtocol + from social_core.storage import UserProtocol def social_details(backend: BaseAuth, details, response, *args, **kwargs): @@ -26,9 +26,7 @@ def social_user( backend: BaseAuth, uid, user: UserProtocol | None = None, *args, **kwargs ): provider = backend.name - social = cast("type[BaseStorage]", backend.strategy.storage).user.get_social_auth( - provider, uid - ) + social = backend.strategy.storage.user.get_social_auth(provider, uid) if social: if user and social.user != user: raise AuthAlreadyAssociated(backend) @@ -52,14 +50,12 @@ def associate_user( ): if user and not social: try: - social = cast( - "type[BaseStorage]", backend.strategy.storage - ).user.create_social_auth(user, uid, backend.name) + social = backend.strategy.storage.user.create_social_auth( + user, uid, backend.name + ) # pylint: disable-next=broad-exception-caught except Exception as err: - if not cast( - "type[BaseStorage]", backend.strategy.storage - ).is_integrity_error(err): + if not backend.strategy.storage.is_integrity_error(err): raise # Protect for possible race condition, those bastard with FTL # clicking capabilities, check issue #131: @@ -95,11 +91,7 @@ def associate_by_email( # Try to associate accounts registered with the same email address, # only if it's a single object. AuthException is raised if multiple # objects are returned. - users = list( - cast("type[BaseStorage]", backend.strategy.storage).user.get_users_by_email( - email - ) - ) + users = list(backend.strategy.storage.user.get_users_by_email(email)) if len(users) == 0: return None if len(users) > 1: @@ -119,9 +111,9 @@ def load_extra_data( *args, **kwargs, ) -> None: - social = kwargs.get("social") or cast( - "type[BaseStorage]", backend.strategy.storage - ).user.get_social_auth(backend.name, uid) + social = kwargs.get("social") or backend.strategy.storage.user.get_social_auth( + backend.name, uid + ) if social: extra_data = backend.extra_data(user, uid, response, details, kwargs) social.set_extra_data(extra_data) diff --git a/social_core/pipeline/user.py b/social_core/pipeline/user.py index 5e37b923..28c2ce90 100644 --- a/social_core/pipeline/user.py +++ b/social_core/pipeline/user.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from social_core.backends.base import BaseAuth - from social_core.storage import BaseStorage, UserProtocol + from social_core.storage import UserProtocol from social_core.strategy import BaseStrategy USER_FIELDS = ["username", "email"] @@ -173,4 +173,4 @@ def user_details( setattr(user, name, value) if changed: - cast("type[BaseStorage]", strategy.storage).user.changed(user) + strategy.storage.user.changed(user) diff --git a/social_core/pipeline/utils.py b/social_core/pipeline/utils.py index 46ec7e1f..755958d0 100644 --- a/social_core/pipeline/utils.py +++ b/social_core/pipeline/utils.py @@ -89,9 +89,7 @@ def partial_load(strategy: BaseStrategy, token: str) -> PartialMixin | None: ).user.get_social_auth(**social) # type: ignore[missing-argument] if user: - kwargs["user"] = cast("type[BaseStorage]", strategy.storage).user.get_user( - user - ) + kwargs["user"] = strategy.storage.user.get_user(user) partial.args = [strategy.from_session_value(val) for val in args] partial.kwargs = { diff --git a/social_core/strategy.py b/social_core/strategy.py index 4cf5bf3f..f2032a3d 100644 --- a/social_core/strategy.py +++ b/social_core/strategy.py @@ -4,7 +4,11 @@ from typing import TYPE_CHECKING, Any, Protocol, cast from .backends.utils import get_backend -from .exceptions import StrategyMissingBackendError, StrategyMissingFeatureError +from .exceptions import ( + SocialAuthImproperlyConfiguredError, + StrategyMissingBackendError, + StrategyMissingFeatureError, +) from .pipeline import DEFAULT_AUTH_PIPELINE, DEFAULT_DISCONNECT_PIPELINE from .pipeline.utils import partial_load from .store import OpenIdSessionWrapper, OpenIdStore @@ -53,9 +57,15 @@ def __init__( storage: type[BaseStorage] | None = None, tpl: type[BaseTemplateStrategy] | None = None, ) -> None: - self.storage = storage + self._storage = storage self.tpl = (tpl or self.DEFAULT_TEMPLATE_STRATEGY)(self) + @property + def storage(self) -> type[BaseStorage]: + if self._storage is None: + raise StrategyMissingBackendError + return self._storage + def setting(self, name: str, default=None, backend: BaseAuth | None = None): names = [setting_name(name), name] if backend: @@ -68,13 +78,9 @@ def setting(self, name: str, default=None, backend: BaseAuth | None = None): return default def create_user(self, *args, **kwargs): - if self.storage is None: - raise StrategyMissingBackendError return self.storage.user.create_user(*args, **kwargs) def get_user(self, *args, **kwargs): - if self.storage is None: - raise StrategyMissingBackendError return self.storage.user.get_user(*args, **kwargs) def session_setdefault(self, name: str, value): @@ -121,8 +127,6 @@ def partial_load(self, token: str) -> PartialMixin | None: return partial_load(self, token) def clean_partial_pipeline(self, token) -> None: - if self.storage is None: - raise StrategyMissingBackendError self.storage.partial.destroy(token) current_token_in_session = self.session_get(PARTIAL_TOKEN_SESSION_NAME) if current_token_in_session == token: @@ -158,17 +162,17 @@ def get_language(self) -> str: def send_email_validation( self, backend: BaseAuth, email: str, partial_token: str | None = None ) -> CodeMixin: - if self.storage is None: - raise StrategyMissingBackendError email_validation = self.setting("EMAIL_VALIDATION_FUNCTION") + if not email_validation: + raise SocialAuthImproperlyConfiguredError( + "EMAIL_VALIDATION_FUNCTION missing" + ) send_email = module_member(email_validation) code = self.storage.code.make_code(email) send_email(self, backend, code, partial_token) return code def validate_email(self, email: str, code: str) -> bool: - if self.storage is None: - raise StrategyMissingBackendError verification_code = self.storage.code.get_code(code) if not verification_code or verification_code.code != code: return False @@ -193,8 +197,6 @@ def authenticate( ) -> UserProtocol | HttpResponseProtocol | None: """Trigger the authentication mechanism tied to the current framework""" - if self.storage is None: - raise StrategyMissingBackendError kwargs["strategy"] = self kwargs["storage"] = self.storage kwargs["backend"] = backend diff --git a/social_core/tests/strategy.py b/social_core/tests/strategy.py index 39183920..88f50361 100644 --- a/social_core/tests/strategy.py +++ b/social_core/tests/strategy.py @@ -6,7 +6,6 @@ if TYPE_CHECKING: from social_core.backends.base import BaseAuth - from social_core.storage import BaseStorage TEST_URI = "http://myapp.com" TEST_HOST = "myapp.com" @@ -30,7 +29,6 @@ def render_string(self, html, context): class TestStrategy(BaseStrategy): __test__ = False - storage: type[BaseStorage] DEFAULT_TEMPLATE_STRATEGY = TestTemplateStrategy diff --git a/social_core/tests/test_strategy_none_storage.py b/social_core/tests/test_strategy_none_storage.py index 0b2a23a1..bece206b 100644 --- a/social_core/tests/test_strategy_none_storage.py +++ b/social_core/tests/test_strategy_none_storage.py @@ -1,7 +1,10 @@ import unittest from social_core.backends.base import BaseAuth -from social_core.exceptions import StrategyMissingBackendError +from social_core.exceptions import ( + SocialAuthImproperlyConfiguredError, + StrategyMissingBackendError, +) from .strategy import TestStrategy @@ -15,7 +18,8 @@ def setUp(self) -> None: def test_strategy_initialization_with_none(self) -> None: """Test that strategy can be initialized with None storage""" - self.assertIsNone(self.strategy.storage) + with self.assertRaises(StrategyMissingBackendError): + self.assertIsNone(self.strategy.storage) def test_create_user_raises_error(self) -> None: """Test that create_user raises StrategyMissingBackendError with None storage""" @@ -44,11 +48,8 @@ def test_clean_partial_pipeline_raises_error(self) -> None: def test_send_email_validation_raises_error(self) -> None: """Test that send_email_validation raises StrategyMissingBackendError with None storage""" backend = BaseAuth(self.strategy) - with self.assertRaises(StrategyMissingBackendError) as cm: + with self.assertRaises(SocialAuthImproperlyConfiguredError): self.strategy.send_email_validation(backend, "test@example.com") - self.assertEqual( - str(cm.exception), "Strategy storage backend is not configured" - ) def test_validate_email_raises_error(self) -> None: """Test that validate_email raises StrategyMissingBackendError with None storage"""