Skip to content

Commit 7985fca

Browse files
committed
chore: improve type annotations
This is needed for proper type annotations in #1491, see also #1490.
1 parent 12deaaa commit 7985fca

File tree

19 files changed

+185
-97
lines changed

19 files changed

+185
-97
lines changed

social_core/actions.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
)
1313

1414
if TYPE_CHECKING:
15+
from collections.abc import Callable
16+
1517
from .backends.base import BaseAuth
16-
from .storage import UserProtocol
18+
from .storage import BaseStorage, UserProtocol
1719
from .strategy import HttpResponseProtocol
1820

1921

@@ -22,7 +24,9 @@ def do_auth(backend: BaseAuth, redirect_name: str = "next") -> HttpResponseProto
2224
data = backend.strategy.request_data(merge=False)
2325

2426
# Save extra data into session.
25-
for field_name in backend.setting("FIELDS_STORED_IN_SESSION", []):
27+
for field_name in cast(
28+
"list[str]", backend.setting("FIELDS_STORED_IN_SESSION", [])
29+
):
2630
if field_name in data:
2731
backend.strategy.session_set(field_name, data[field_name])
2832
else:
@@ -33,7 +37,7 @@ def do_auth(backend: BaseAuth, redirect_name: str = "next") -> HttpResponseProto
3337
redirect_uri = data[redirect_name]
3438
if backend.setting("SANITIZE_REDIRECTS", True):
3539
allowed_hosts = [
36-
*backend.setting("ALLOWED_REDIRECT_HOSTS", []),
40+
*cast("list[str]", backend.setting("ALLOWED_REDIRECT_HOSTS", [])),
3741
backend.strategy.request_host(),
3842
]
3943
redirect_uri = sanitize_redirect(allowed_hosts, redirect_uri)
@@ -43,9 +47,9 @@ def do_auth(backend: BaseAuth, redirect_name: str = "next") -> HttpResponseProto
4347
return backend.start()
4448

4549

46-
def do_complete(
50+
def do_complete( # noqa: C901,PLR0912
4751
backend: BaseAuth,
48-
login,
52+
login: Callable,
4953
user: UserProtocol | None = None,
5054
redirect_name: str = "next",
5155
*args,
@@ -54,15 +58,19 @@ def do_complete(
5458
data = backend.strategy.request_data()
5559

5660
is_authenticated = user_is_authenticated(user)
57-
user = user if is_authenticated else None
61+
authenticated_user: UserProtocol | HttpResponseProtocol | None = (
62+
user if is_authenticated else None
63+
)
5864

5965
partial = partial_pipeline_data(backend, user, *args, **kwargs)
6066
if partial:
61-
user = backend.continue_pipeline(partial)
67+
authenticated_user = backend.continue_pipeline(partial)
6268
# clean partial data after usage
6369
backend.strategy.clean_partial_pipeline(partial.token)
6470
else:
65-
user = backend.complete(*args, user=user, redirect_name=redirect_name, **kwargs)
71+
authenticated_user = backend.complete(
72+
*args, user=authenticated_user, redirect_name=redirect_name, **kwargs
73+
)
6674

6775
# pop redirect value before the session is trashed on login(), but after
6876
# the pipeline so that the pipeline can change the redirect if needed
@@ -72,12 +80,15 @@ def do_complete(
7280

7381
# check if the output value is something else than a user and just
7482
# return it to the client
75-
user_model = backend.strategy.storage.user.user_model()
76-
if user and not isinstance(user, user_model):
77-
return cast("HttpResponseProtocol", user)
83+
user_model = cast("type[BaseStorage]", backend.strategy.storage).user.user_model()
84+
if authenticated_user and not isinstance(authenticated_user, user_model):
85+
return cast("HttpResponseProtocol", authenticated_user)
86+
87+
authenticated_user = cast("UserProtocol | None", authenticated_user)
88+
url: str | None
7889

7990
if is_authenticated:
80-
if not user:
91+
if not authenticated_user:
8192
url = setting_url(backend, redirect_value, "LOGIN_REDIRECT_URL")
8293
else:
8394
url = setting_url(
@@ -86,17 +97,17 @@ def do_complete(
8697
"NEW_ASSOCIATION_REDIRECT_URL",
8798
"LOGIN_REDIRECT_URL",
8899
)
89-
elif user:
100+
elif authenticated_user:
90101
# check if inactive users are allowed to login
91102
bypass_inactivation = backend.strategy.setting(
92103
"ALLOW_INACTIVE_USERS_LOGIN", False
93104
)
94-
if bypass_inactivation or user_is_active(user):
105+
if bypass_inactivation or user_is_active(authenticated_user):
95106
# catch is_new/social_user in case login() resets the instance
96107
# These attributes are set in BaseAuth.pipeline()
97-
is_new = getattr(user, "is_new", False)
98-
social_user = user.social_user # type: ignore[union-attr]
99-
login(backend, user, social_user)
108+
is_new = getattr(authenticated_user, "is_new", False)
109+
social_user = authenticated_user.social_user # type: ignore[union-attr]
110+
login(backend, authenticated_user, social_user)
100111
# store last login backend name in session
101112
backend.strategy.session_set(
102113
"social_auth_last_login_backend", social_user.provider
@@ -114,28 +125,31 @@ def do_complete(
114125
else:
115126
if backend.setting("INACTIVE_USER_LOGIN", False):
116127
# This attribute is set in BaseAuth.pipeline()
117-
social_user = user.social_user # type: ignore[union-attr]
118-
login(backend, user, social_user)
128+
social_user = authenticated_user.social_user # type: ignore[union-attr]
129+
login(backend, authenticated_user, social_user)
119130
url = setting_url(
120131
backend, "INACTIVE_USER_URL", "LOGIN_ERROR_URL", "LOGIN_URL"
121132
)
122133
else:
123134
url = setting_url(backend, "LOGIN_ERROR_URL", "LOGIN_URL")
124135

125-
assert url, "By this point URL has to have been set"
136+
if not url:
137+
raise ValueError("By this point URL has to have been set")
126138

127139
if redirect_value and redirect_value != url:
128140
redirect_value = quote(redirect_value)
129141
url += ("&" if "?" in url else "?") + f"{redirect_name}={redirect_value}"
130142

131143
if backend.setting("SANITIZE_REDIRECTS", True):
132144
allowed_hosts = [
133-
*backend.setting("ALLOWED_REDIRECT_HOSTS", []),
145+
*cast("list[str]", backend.setting("ALLOWED_REDIRECT_HOSTS", [])),
134146
backend.strategy.request_host(),
135147
]
136148
url = sanitize_redirect(allowed_hosts, url) or backend.setting(
137149
"LOGIN_REDIRECT_URL"
138150
)
151+
if url is None:
152+
raise ValueError("Disallowed URL")
139153
return backend.strategy.redirect(url)
140154

141155

@@ -160,20 +174,22 @@ def do_disconnect(
160174
)
161175

162176
if isinstance(response, dict):
163-
url = backend.strategy.absolute_uri(
177+
url: str | None = backend.strategy.absolute_uri(
164178
backend.strategy.request_data().get(redirect_name, "")
165179
or backend.setting("DISCONNECT_REDIRECT_URL")
166180
or backend.setting("LOGIN_REDIRECT_URL")
167181
)
168182
if backend.setting("SANITIZE_REDIRECTS", True):
169183
allowed_hosts = [
170-
*backend.setting("ALLOWED_REDIRECT_HOSTS", []),
184+
*cast("list[str]", backend.setting("ALLOWED_REDIRECT_HOSTS", [])),
171185
backend.strategy.request_host(),
172186
]
173187
url = (
174188
sanitize_redirect(allowed_hosts, url)
175189
or backend.setting("DISCONNECT_REDIRECT_URL")
176190
or backend.setting("LOGIN_REDIRECT_URL")
177191
)
192+
if not url:
193+
raise ValueError("Disallowed URL")
178194
response = backend.strategy.redirect(url)
179195
return response

social_core/backends/apple.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import json
2727
import time
28-
from typing import TYPE_CHECKING
28+
from typing import TYPE_CHECKING, cast
2929

3030
import jwt
3131
from jwt.algorithms import RSAAlgorithm # ty: ignore[possibly-missing-import]
@@ -71,12 +71,12 @@ def auth_params(self, *args, **kwargs):
7171
params["response_mode"] = "form_post"
7272
return params
7373

74-
def get_private_key(self):
74+
def get_private_key(self) -> str:
7575
"""
7676
Return contents of the private key file. Override this method to provide
7777
secret key from another source if needed.
7878
"""
79-
return self.setting("SECRET")
79+
return cast("str", self.setting("SECRET"))
8080

8181
def generate_client_secret(self):
8282
now = int(time.time())

social_core/backends/azuread_b2c.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from __future__ import annotations
3131

3232
import json
33-
from typing import TYPE_CHECKING, Any, Literal
33+
from typing import TYPE_CHECKING, Any, Literal, cast
3434

3535
from cryptography.hazmat.primitives import serialization
3636
from jwt import DecodeError, ExpiredSignatureError, get_unverified_header
@@ -198,7 +198,7 @@ def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None
198198
key=key,
199199
algorithms=["RS256"],
200200
audience=self.setting("KEY"),
201-
leeway=self.setting("JWT_LEEWAY", default=0),
201+
leeway=cast("int", self.setting("JWT_LEEWAY", default=0)),
202202
)
203203
except (DecodeError, ExpiredSignatureError) as error:
204204
raise AuthTokenError(self, error)

social_core/backends/azuread_tenant.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import base64
2-
from typing import Any
2+
from typing import Any, cast
33

44
from cryptography.hazmat.backends import default_backend
55
from cryptography.x509 import load_der_x509_certificate
@@ -51,8 +51,8 @@ class AzureADTenantOAuth2(AzureADOAuth2):
5151
JWKS_URL = "{base_url}/discovery/keys{appid}"
5252

5353
@property
54-
def tenant_id(self):
55-
return self.setting("TENANT_ID", "common")
54+
def tenant_id(self) -> str:
55+
return cast("str", self.setting("TENANT_ID", "common"))
5656

5757
def openid_configuration_url(self):
5858
return self.OPENID_CONFIGURATION_URL.format(

social_core/backends/base.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from social_core.exceptions import AuthConnectionError, AuthUnknownError
1010
from social_core.registry import REGISTRY
11-
from social_core.storage import UserProtocol
11+
from social_core.storage import BaseStorage, UserProtocol
1212
from social_core.utils import module_member, parse_qs, social_logger, user_agent
1313

1414
if TYPE_CHECKING:
@@ -17,8 +17,8 @@
1717
from requests import Response
1818
from requests.auth import AuthBase
1919

20-
from social_core.storage import UserProtocol
21-
from social_core.strategy import HttpResponseProtocol
20+
from social_core.storage import BaseStorage, PartialMixin, UserProtocol
21+
from social_core.strategy import BaseStrategy, HttpResponseProtocol
2222

2323

2424
class BaseAuth:
@@ -33,8 +33,9 @@ class BaseAuth:
3333
REQUIRES_EMAIL_VALIDATION = False
3434
SEND_USER_AGENT = True
3535

36-
def __init__(self, strategy=None, redirect_uri: str | None = None) -> None:
37-
# TODO: temporary type override
36+
def __init__(
37+
self, strategy: BaseStrategy | None = None, redirect_uri: str | None = None
38+
) -> None:
3839
self.strategy: Any = (
3940
strategy if strategy is not None else REGISTRY.default_strategy
4041
)
@@ -57,7 +58,7 @@ def start(self) -> HttpResponseProtocol:
5758
return self.strategy.redirect(self.auth_url())
5859
return self.strategy.html(self.auth_html())
5960

60-
def complete(self, *args, **kwargs) -> UserProtocol | None:
61+
def complete(self, *args, **kwargs) -> HttpResponseProtocol | UserProtocol | None:
6162
return self.auth_complete(*args, **kwargs)
6263

6364
def auth_url(self) -> str:
@@ -68,7 +69,9 @@ def auth_html(self) -> str:
6869
"""Must return login HTML content returned by provider"""
6970
return "Implement in subclass"
7071

71-
def auth_complete(self, *args, **kwargs) -> UserProtocol | None:
72+
def auth_complete(
73+
self, *args, **kwargs
74+
) -> HttpResponseProtocol | UserProtocol | None:
7275
"""Completes login process, must return user instance"""
7376
raise NotImplementedError("Implement in subclass")
7477

@@ -120,7 +123,7 @@ def pipeline(
120123
def disconnect(self, *args, **kwargs) -> dict:
121124
pipeline = self.strategy.get_disconnect_pipeline(self)
122125
kwargs["name"] = self.name
123-
kwargs["user_storage"] = self.strategy.storage.user
126+
kwargs["user_storage"] = cast("type[BaseStorage]", self.strategy.storage).user
124127
return self.run_pipeline(pipeline, *args, **kwargs)
125128

126129
def run_pipeline(
@@ -194,8 +197,14 @@ def extra_data(
194197
def auth_allowed(self, response, details):
195198
"""Return True if the user should be allowed to authenticate, by
196199
default check if email is whitelisted (if there's a whitelist)"""
197-
emails = [email.lower() for email in self.setting("WHITELISTED_EMAILS", [])]
198-
domains = [domain.lower() for domain in self.setting("WHITELISTED_DOMAINS", [])]
200+
emails = [
201+
email.lower()
202+
for email in cast("list[str]", self.setting("WHITELISTED_EMAILS", []))
203+
]
204+
domains = [
205+
domain.lower()
206+
for domain in cast("list[str]", self.setting("WHITELISTED_DOMAINS", []))
207+
]
199208
email = details.get("email")
200209
allowed = True
201210
if email and (emails or domains):
@@ -249,7 +258,9 @@ def get_user(self, user_id):
249258
"""
250259
return self.strategy.get_user(user_id)
251260

252-
def continue_pipeline(self, partial):
261+
def continue_pipeline(
262+
self, partial: PartialMixin
263+
) -> UserProtocol | HttpResponseProtocol | None:
253264
"""Continue previous halted pipeline"""
254265
return self.strategy.authenticate(
255266
self, *partial.args, pipeline_index=partial.next_step, **partial.kwargs
@@ -336,7 +347,7 @@ def get_key_and_secret(self) -> tuple[str, str]:
336347
"""Return tuple with Consumer Key and Consumer Secret for current
337348
service provider. Must return (key, secret), order *must* be respected.
338349
"""
339-
return self.setting("KEY"), self.setting("SECRET")
350+
return cast("str", self.setting("KEY")), cast("str", self.setting("SECRET"))
340351

341352
def get_key_and_secret_basic_auth(self) -> bytes:
342353
"""Generate HTTP Basic Authentication header value from KEY and SECRET.

social_core/backends/discourse.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22
import time
33
from base64 import urlsafe_b64decode, urlsafe_b64encode
44
from hashlib import sha256
5+
from typing import TYPE_CHECKING, cast
56
from urllib.parse import urlencode
67

78
from social_core.exceptions import AuthException, AuthTokenError
89
from social_core.utils import parse_qs
910

1011
from .base import BaseAuth
1112

13+
if TYPE_CHECKING:
14+
from social_core.storage import BaseStorage
15+
1216

1317
class DiscourseAuth(BaseAuth):
1418
name = "discourse"
@@ -51,13 +55,17 @@ def get_user_details(self, response):
5155
}
5256

5357
def add_nonce(self, nonce) -> None:
54-
self.strategy.storage.nonce.use(self.setting("SERVER_URL"), time.time(), nonce)
58+
cast("type[BaseStorage]", self.strategy.storage).nonce.use(
59+
self.setting("SERVER_URL"), time.time(), nonce
60+
)
5561

5662
def get_nonce(self, nonce):
57-
return self.strategy.storage.nonce.get(self.setting("SERVER_URL"), nonce)
63+
return cast("type[BaseStorage]", self.strategy.storage).nonce.get(
64+
self.setting("SERVER_URL"), nonce
65+
)
5866

5967
def delete_nonce(self, nonce) -> None:
60-
self.strategy.storage.nonce.delete(nonce)
68+
cast("type[BaseStorage]", self.strategy.storage).nonce.delete(nonce)
6169

6270
def auth_complete(self, *args, **kwargs):
6371
"""

0 commit comments

Comments
 (0)