Skip to content

Commit f927986

Browse files
committed
chore: fix type annotations
Properly annotate strategy and deal with the mess.
1 parent 06921a6 commit f927986

28 files changed

+135
-90
lines changed

social_core/backends/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class BaseAuth:
3535
def __init__(
3636
self, strategy: BaseStrategy | None = None, redirect_uri: str | None = None
3737
) -> None:
38-
self.strategy: Any = (
38+
self.strategy: BaseStrategy = (
3939
strategy if strategy is not None else REGISTRY.default_strategy
4040
)
4141
self.redirect_uri = redirect_uri

social_core/backends/box.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ def get_user_details(self, response):
4646

4747
def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None:
4848
"""Loads user data from service"""
49-
params = self.setting("PROFILE_EXTRA_PARAMS", {})
49+
params = cast("dict", self.setting("PROFILE_EXTRA_PARAMS", {}))
5050
params["access_token"] = access_token
5151
return self.get_json("https://api.box.com/2.0/users/me", params=params)

social_core/backends/cas.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
implementation and the standard OIDC implementation in Python Social Auth.
88
"""
99

10-
from typing import Any, cast
10+
from typing import TYPE_CHECKING, Any, cast
1111

1212
from .open_id_connect import OpenIdConnectAuth
1313

14+
if TYPE_CHECKING:
15+
from collections.abc import Iterable
16+
1417

1518
class CASOpenIdConnectAuth(OpenIdConnectAuth):
1619
"""
@@ -30,7 +33,7 @@ class CASOpenIdConnectAuth(OpenIdConnectAuth):
3033
STATE_PARAMETER = True
3134

3235
def oidc_endpoint(self):
33-
endpoint = self.setting("OIDC_ENDPOINT", self.OIDC_ENDPOINT)
36+
endpoint = super().oidc_endpoint()
3437
self.log_debug("endpoint: %s", endpoint)
3538
return endpoint
3639

@@ -58,7 +61,7 @@ def get_user_details(self, response):
5861
}
5962

6063
def auth_allowed(self, response, details):
61-
allow_groups = set(self.setting("ALLOW_GROUPS", set()))
64+
allow_groups = set(cast("Iterable", self.setting("ALLOW_GROUPS", set())))
6265
groups = set(response.get("groups", set()))
6366
return (
6467
super().auth_allowed(response, details)

social_core/backends/egi_checkin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
from __future__ import annotations
77

8-
from typing import Literal
8+
from typing import Literal, cast
99

1010
from social_core.backends.open_id_connect import OpenIdConnectAuth
1111

12-
CHECKIN_ENV_ENDPOINTS = {
12+
CHECKIN_ENV_ENDPOINTS: dict[str, str] = {
1313
"prod": "https://aai.egi.eu/auth/realms/egi",
1414
"demo": "https://aai-demo.egi.eu/auth/realms/egi",
1515
"dev": "https://aai-dev.egi.eu/auth/realms/egi",
@@ -48,7 +48,7 @@ def oidc_endpoint(self):
4848
endpoint = self.setting("OIDC_ENDPOINT", self.OIDC_ENDPOINT)
4949
if endpoint:
5050
return endpoint
51-
checkin_env = self.setting("CHECKIN_ENV", self.CHECKIN_ENV)
51+
checkin_env = cast("str", self.setting("CHECKIN_ENV", self.CHECKIN_ENV))
5252
return CHECKIN_ENV_ENDPOINTS.get(checkin_env, "")
5353

5454
def get_user_details(self, response):

social_core/backends/facebook.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import hmac
99
import json
1010
import time
11-
from typing import Any
11+
from typing import Any, cast
1212

1313
from social_core.exceptions import (
1414
AuthCanceled,
@@ -49,10 +49,10 @@ def auth_params(self, state=None):
4949
return params
5050

5151
def get_authorization_url_format(self) -> dict[str, str]:
52-
return {"version": self.setting("API_VERSION", API_VERSION)}
52+
return {"version": cast("str", self.setting("API_VERSION", API_VERSION))}
5353

5454
def get_access_token_url_format(self) -> dict[str, str]:
55-
return {"version": self.setting("API_VERSION", API_VERSION)}
55+
return {"version": cast("str", self.setting("API_VERSION", API_VERSION))}
5656

5757
def get_user_details(self, response):
5858
"""Return user details from Facebook account"""

social_core/backends/keycloak.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, cast
22

33
import jwt
44

@@ -123,7 +123,7 @@ def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None
123123
key=self.public_key(),
124124
algorithms=self.algorithm(),
125125
audience=self.audience(),
126-
leeway=self.setting("JWT_LEEWAY", default=0),
126+
leeway=cast("int", self.setting("JWT_LEEWAY", default=0)),
127127
)
128128

129129
def get_user_details(self, response):

social_core/backends/legacy.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
1+
from typing import cast
2+
13
from social_core.exceptions import AuthMissingParameter
24

35
from .base import BaseAuth
46

57

68
class LegacyAuth(BaseAuth):
7-
def auth_url(self):
8-
return self.setting("FORM_URL")
9+
def auth_url(self) -> str:
10+
return cast("str", self.setting("FORM_URL"))
911

10-
def auth_html(self):
11-
return self.strategy.render_html(tpl=self.setting("FORM_HTML"))
12+
def auth_html(self) -> str:
13+
return self.strategy.render_html(tpl=cast("str", self.setting("FORM_HTML")))
1214

13-
def uses_redirect(self):
14-
return self.setting("FORM_URL") and not self.setting("FORM_HTML")
15+
def uses_redirect(self) -> bool:
16+
return bool(cast("str", self.setting("FORM_URL"))) and not cast(
17+
"str", self.setting("FORM_HTML")
18+
)
1519

1620
def auth_complete(self, *args, **kwargs):
1721
"""Completes login process, must return user instance"""

social_core/backends/linkedin.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import datetime
99
from calendar import timegm
10-
from typing import TYPE_CHECKING, Any, Literal
10+
from typing import TYPE_CHECKING, Any, Literal, cast
1111

1212
from social_core.backends.open_id_connect import OpenIdConnectAuth
1313
from social_core.exceptions import AuthCanceled, AuthTokenError
@@ -75,7 +75,12 @@ class LinkedinOAuth2(BaseOAuth2):
7575
def user_details_url(self):
7676
# use set() since LinkedIn fails when values are duplicated
7777
fields_selectors = list(
78-
{"id", "firstName", "lastName", *self.setting("FIELD_SELECTORS", [])}
78+
{
79+
"id",
80+
"firstName",
81+
"lastName",
82+
*cast("list[str]", self.setting("FIELD_SELECTORS", [])),
83+
}
7984
)
8085
# user sort to ease the tests URL mocking
8186
fields_selectors.sort()
@@ -90,7 +95,9 @@ def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None
9095
self.user_details_url(), headers=self.user_data_headers(access_token)
9196
)
9297

93-
if "emailAddress" in set(self.setting("FIELD_SELECTORS", [])):
98+
if "emailAddress" in set(
99+
cast("list[str]", self.setting("FIELD_SELECTORS", []))
100+
):
94101
emails = self.email_data(access_token, *args, **kwargs)
95102
if emails:
96103
response["emailAddress"] = emails[0]

social_core/backends/mediawiki.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import re
77
import time
8+
from typing import cast
89
from urllib.parse import parse_qs, urlencode, urlparse
910

1011
import jwt
@@ -46,7 +47,7 @@ def unauthorized_token(self):
4647
params["title"] = "Special:OAuth/initiate"
4748
key, secret = self.get_key_and_secret()
4849
response = self.request(
49-
self.setting("MEDIAWIKI_URL"),
50+
cast("str", self.setting("MEDIAWIKI_URL")),
5051
params=params,
5152
auth=OAuth1(key, secret, callback_uri=self.setting("CALLBACK")),
5253
method=self.REQUEST_TOKEN_METHOD,
@@ -83,7 +84,7 @@ def access_token(self, token):
8384
auth_token = self.oauth_auth(token)
8485

8586
response = self.request(
86-
self.setting("MEDIAWIKI_URL"),
87+
cast("str", self.setting("MEDIAWIKI_URL")),
8788
method="POST",
8889
params={"title": "Special:Oauth/token"},
8990
auth=auth_token,
@@ -116,7 +117,7 @@ def get_user_details(self, response):
116117
)
117118

118119
req_resp = self.request(
119-
self.setting("MEDIAWIKI_URL"),
120+
cast("str", self.setting("MEDIAWIKI_URL")),
120121
method="POST",
121122
params={"title": "Special:OAuth/identify"},
122123
auth=auth,

social_core/backends/mineid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, cast
22

33
from .oauth import BaseOAuth2
44

@@ -31,6 +31,6 @@ def get_access_token_url_format(self) -> dict[str, str]:
3131

3232
def get_mineid_url_params(self) -> dict[str, str]:
3333
return {
34-
"host": self.setting("HOST", "www.mineid.org"),
35-
"scheme": self.setting("SCHEME", "https"),
34+
"host": cast("str", self.setting("HOST", "www.mineid.org")),
35+
"scheme": cast("str", self.setting("SCHEME", "https")),
3636
}

0 commit comments

Comments
 (0)