Skip to content

Commit 2d1fbf9

Browse files
committed
chore: improve type annotations
- Add type annotations - Add safeguards against some unexpected values - Fixed response type for the test backend and adjust related tests
1 parent 72f2ced commit 2d1fbf9

File tree

18 files changed

+165
-88
lines changed

18 files changed

+165
-88
lines changed

social_core/actions.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, cast
14
from urllib.parse import quote
25

36
from .utils import (
@@ -8,8 +11,13 @@
811
user_is_authenticated,
912
)
1013

14+
if TYPE_CHECKING:
15+
from .backends.base import BaseAuth
16+
from .storage import UserProtocol
17+
from .strategy import HttpResponseProtocol
18+
1119

12-
def do_auth(backend, redirect_name="next"):
20+
def do_auth(backend: BaseAuth, redirect_name: str = "next") -> HttpResponseProtocol:
1321
# Save any defined next value into session
1422
data = backend.strategy.request_data(merge=False)
1523

@@ -35,7 +43,14 @@ def do_auth(backend, redirect_name="next"):
3543
return backend.start()
3644

3745

38-
def do_complete(backend, login, user=None, redirect_name="next", *args, **kwargs):
46+
def do_complete(
47+
backend: BaseAuth,
48+
login,
49+
user: UserProtocol | None = None,
50+
redirect_name: str = "next",
51+
*args,
52+
**kwargs,
53+
) -> HttpResponseProtocol:
3954
data = backend.strategy.request_data()
4055

4156
is_authenticated = user_is_authenticated(user)
@@ -59,7 +74,7 @@ def do_complete(backend, login, user=None, redirect_name="next", *args, **kwargs
5974
# return it to the client
6075
user_model = backend.strategy.storage.user.user_model()
6176
if user and not isinstance(user, user_model):
62-
return user
77+
return cast("HttpResponseProtocol", user)
6378

6479
if is_authenticated:
6580
if not user:
@@ -78,8 +93,9 @@ def do_complete(backend, login, user=None, redirect_name="next", *args, **kwargs
7893
)
7994
if bypass_inactivation or user_is_active(user):
8095
# catch is_new/social_user in case login() resets the instance
96+
# These attributes are set in BaseAuth.pipeline()
8197
is_new = getattr(user, "is_new", False)
82-
social_user = user.social_user
98+
social_user = user.social_user # type: ignore[union-attr]
8399
login(backend, user, social_user)
84100
# store last login backend name in session
85101
backend.strategy.session_set(
@@ -97,7 +113,8 @@ def do_complete(backend, login, user=None, redirect_name="next", *args, **kwargs
97113
url = setting_url(backend, redirect_value, "LOGIN_REDIRECT_URL")
98114
else:
99115
if backend.setting("INACTIVE_USER_LOGIN", False):
100-
social_user = user.social_user
116+
# This attribute is set in BaseAuth.pipeline()
117+
social_user = user.social_user # type: ignore[union-attr]
101118
login(backend, user, social_user)
102119
url = setting_url(
103120
backend, "INACTIVE_USER_URL", "LOGIN_ERROR_URL", "LOGIN_URL"
@@ -123,7 +140,12 @@ def do_complete(backend, login, user=None, redirect_name="next", *args, **kwargs
123140

124141

125142
def do_disconnect(
126-
backend, user, association_id=None, redirect_name="next", *args, **kwargs
143+
backend: BaseAuth,
144+
user: UserProtocol,
145+
association_id=None,
146+
redirect_name: str = "next",
147+
*args,
148+
**kwargs,
127149
):
128150
partial = partial_pipeline_data(backend, user, *args, **kwargs)
129151
if partial:

social_core/backends/apple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from typing import TYPE_CHECKING
2929

3030
import jwt
31-
from jwt.algorithms import RSAAlgorithm # ty: ignore[possibly-unbound-import]
31+
from jwt.algorithms import RSAAlgorithm # ty: ignore[possibly-missing-import]
3232
from jwt.exceptions import PyJWTError
3333

3434
from social_core.backends.oauth import BaseOAuth2

social_core/backends/azuread_b2c.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from cryptography.hazmat.primitives import serialization
3636
from jwt import DecodeError, ExpiredSignatureError, get_unverified_header
3737
from jwt import decode as jwt_decode
38-
from jwt.algorithms import RSAAlgorithm # ty: ignore[possibly-unbound-import]
38+
from jwt.algorithms import RSAAlgorithm # ty: ignore[possibly-missing-import]
3939

4040
from social_core.exceptions import AuthException, AuthTokenError
4141

social_core/backends/base.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from requests.auth import AuthBase
1717

1818
from social_core.storage import UserProtocol
19+
from social_core.strategy import HttpResponseProtocol
1920

2021

2122
class BaseAuth:
@@ -46,12 +47,12 @@ def setting(self, name: str, default=None):
4647
"""Return setting value from strategy"""
4748
return self.strategy.setting(name, default=default, backend=self)
4849

49-
def start(self):
50+
def start(self) -> HttpResponseProtocol:
5051
if self.uses_redirect():
5152
return self.strategy.redirect(self.auth_url())
5253
return self.strategy.html(self.auth_html())
5354

54-
def complete(self, *args, **kwargs):
55+
def complete(self, *args, **kwargs) -> UserProtocol | None:
5556
return self.auth_complete(*args, **kwargs)
5657

5758
def auth_url(self) -> str:
@@ -62,15 +63,17 @@ def auth_html(self) -> str:
6263
"""Must return login HTML content returned by provider"""
6364
return "Implement in subclass"
6465

65-
def auth_complete(self, *args, **kwargs):
66+
def auth_complete(self, *args, **kwargs) -> UserProtocol | None:
6667
"""Completes login process, must return user instance"""
6768
raise NotImplementedError("Implement in subclass")
6869

6970
def process_error(self, data) -> None:
7071
"""Process data for errors, raise exception if needed.
7172
Call this method on any override of auth_complete."""
7273

73-
def authenticate(self, *args, **kwargs):
74+
def authenticate(
75+
self, *args, **kwargs
76+
) -> UserProtocol | HttpResponseProtocol | None:
7477
"""Authenticate user using social credentials
7578
7679
Authentication is made if this is the correct backend, backend
@@ -97,23 +100,27 @@ def authenticate(self, *args, **kwargs):
97100
args, kwargs = self.strategy.clean_authenticate_args(*args, **kwargs)
98101
return self.pipeline(pipeline, *args, **kwargs)
99102

100-
def pipeline(self, pipeline, pipeline_index=0, *args, **kwargs):
103+
def pipeline(
104+
self, pipeline, pipeline_index: int = 0, *args, **kwargs
105+
) -> UserProtocol | HttpResponseProtocol | None:
101106
out = self.run_pipeline(pipeline, pipeline_index, *args, **kwargs)
102107
if not isinstance(out, dict):
103-
return out
104-
user = out.get("user")
108+
return cast("HttpResponseProtocol", out)
109+
user = cast("UserProtocol | None", out.get("user"))
105110
if user:
106-
user.social_user = out.get("social")
107-
user.is_new = out.get("is_new")
111+
user.social_user = out.get("social") # type: ignore[attr-defined]
112+
user.is_new = out.get("is_new") # type: ignore[attr-defined]
108113
return user
109114

110-
def disconnect(self, *args, **kwargs):
115+
def disconnect(self, *args, **kwargs) -> dict:
111116
pipeline = self.strategy.get_disconnect_pipeline(self)
112117
kwargs["name"] = self.name
113118
kwargs["user_storage"] = self.strategy.storage.user
114119
return self.run_pipeline(pipeline, *args, **kwargs)
115120

116-
def run_pipeline(self, pipeline: list[str], pipeline_index=0, *args, **kwargs):
121+
def run_pipeline(
122+
self, pipeline: list[str], pipeline_index=0, *args, **kwargs
123+
) -> dict:
117124
out = kwargs.copy()
118125
out.setdefault("strategy", self.strategy)
119126
out.setdefault("backend", out.pop(self.name, None) or self)

social_core/backends/saml.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from __future__ import annotations
1212

1313
import json
14-
from typing import Any
14+
from typing import Any, cast
1515

1616
from onelogin.saml2.auth import OneLogin_Saml2_Auth
1717
from onelogin.saml2.settings import OneLogin_Saml2_Settings
@@ -318,8 +318,10 @@ def generate_saml_config(self, idp: SAMLIdentityProvider | None = None):
318318
},
319319
"strict": True, # We must force strict mode - for security
320320
}
321-
config["security"].update(self.setting("SECURITY_CONFIG", {}))
322-
config["sp"].update(self.setting("SP_EXTRA", {}))
321+
cast("dict", config["security"]).update(
322+
cast("dict", self.setting("SECURITY_CONFIG", {}))
323+
)
324+
cast("dict", config["sp"]).update(cast("dict", self.setting("SP_EXTRA", {})))
323325
return config
324326

325327
def generate_metadata_xml(self):

social_core/pipeline/user.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from typing import TYPE_CHECKING, cast
44
from uuid import uuid4
55

6+
from social_core.exceptions import (
7+
StrategyMissingBackendError,
8+
)
69
from social_core.utils import module_member, slugify
710

811
if TYPE_CHECKING:
@@ -21,6 +24,8 @@ def get_username(
2124
*args,
2225
**kwargs,
2326
):
27+
if strategy.storage is None:
28+
raise StrategyMissingBackendError
2429
if "username" not in backend.setting("USER_FIELDS", USER_FIELDS):
2530
return None
2631
storage = strategy.storage
@@ -110,6 +115,8 @@ def user_details(
110115
**kwargs,
111116
) -> None:
112117
"""Update user details using data from provider."""
118+
if strategy.storage is None:
119+
raise StrategyMissingBackendError
113120
if not user:
114121
return
115122

social_core/pipeline/utils.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22

33
from typing import TYPE_CHECKING
44

5+
from social_core.exceptions import (
6+
StrategyMissingBackendError,
7+
)
8+
59
if TYPE_CHECKING:
610
from social_core.backends.base import BaseAuth
7-
from social_core.storage import UserProtocol
11+
from social_core.storage import PartialMixin, UserProtocol
812
from social_core.strategy import BaseStrategy
913

1014
SERIALIZABLE_TYPES = (dict, list, tuple, set, bool, type(None), int, str, bytes)
@@ -28,7 +32,9 @@ def partial_prepare(
2832
social=None,
2933
*args,
3034
**kwargs,
31-
):
35+
) -> PartialMixin:
36+
if strategy.storage is None:
37+
raise StrategyMissingBackendError
3238
kwargs.update(
3339
{
3440
"response": kwargs.get("response") or {},
@@ -59,12 +65,16 @@ def partial_prepare(
5965

6066
def partial_store(
6167
strategy: BaseStrategy, backend: BaseAuth, next_step, *args, **kwargs
62-
):
68+
) -> PartialMixin:
69+
if strategy.storage is None:
70+
raise StrategyMissingBackendError
6371
partial = partial_prepare(strategy, backend, next_step, *args, **kwargs)
6472
return strategy.storage.partial.store(partial)
6573

6674

67-
def partial_load(strategy: BaseStrategy, token):
75+
def partial_load(strategy: BaseStrategy, token: str) -> PartialMixin | None:
76+
if strategy.storage is None:
77+
raise StrategyMissingBackendError
6878
partial = strategy.storage.partial.load(token)
6979

7080
if partial:

social_core/storage.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from .exceptions import InvalidExpiryValue, MissingBackend
1515

1616
if TYPE_CHECKING:
17+
from collections.abc import Callable
18+
1719
from social_core.backends.base import BaseAuth
1820
from social_core.strategy import BaseStrategy
1921

@@ -24,6 +26,12 @@
2426
class UserProtocol(Protocol):
2527
id: int
2628
username: str
29+
is_active: bool | Callable[[], bool]
30+
is_authenticated: bool | Callable[[], bool]
31+
32+
# Set in BaseAuth.pipeline
33+
# social_user: UserMixin
34+
# is_new: bool
2735

2836

2937
class UserMixin:
@@ -386,6 +394,9 @@ def args(self):
386394
def args(self, value) -> None:
387395
self.data["args"] = value
388396

397+
@abstractmethod
398+
def save(self): ...
399+
389400
@property
390401
def kwargs(self):
391402
return self.data.get("kwargs", {})
@@ -398,19 +409,21 @@ def extend_kwargs(self, values) -> None:
398409
self.data["kwargs"].update(values)
399410

400411
@classmethod
401-
def generate_token(cls):
412+
def generate_token(cls) -> str:
402413
return uuid.uuid4().hex
403414

404415
@classmethod
405-
def load(cls, token):
416+
def load(cls, token: str) -> PartialMixin | None:
406417
raise NotImplementedError("Implement in subclass")
407418

408419
@classmethod
409-
def destroy(cls, token):
420+
def destroy(cls, token: str):
410421
raise NotImplementedError("Implement in subclass")
411422

412423
@classmethod
413-
def prepare(cls, backend, next_step: int, data: dict[str, Any]):
424+
def prepare(
425+
cls, backend: str, next_step: int, data: dict[str, Any]
426+
) -> PartialMixin:
414427
partial = cls()
415428
partial.backend = backend
416429
partial.next_step = next_step
@@ -419,7 +432,7 @@ def prepare(cls, backend, next_step: int, data: dict[str, Any]):
419432
return partial
420433

421434
@classmethod
422-
def store(cls, partial):
435+
def store(cls, partial: PartialMixin) -> PartialMixin:
423436
partial.save()
424437
return partial
425438

0 commit comments

Comments
 (0)