Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion social_core/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class BaseAuth:
def __init__(
self, strategy: BaseStrategy | None = None, redirect_uri: str | None = None
) -> None:
self.strategy: Any = (
self.strategy: BaseStrategy = (
strategy if strategy is not None else REGISTRY.default_strategy
)
self.redirect_uri = redirect_uri
Expand Down
2 changes: 1 addition & 1 deletion social_core/backends/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ def get_user_details(self, response):

def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None:
"""Loads user data from service"""
params = self.setting("PROFILE_EXTRA_PARAMS", {})
params = cast("dict", self.setting("PROFILE_EXTRA_PARAMS", {}))
params["access_token"] = access_token
return self.get_json("https://api.box.com/2.0/users/me", params=params)
9 changes: 6 additions & 3 deletions social_core/backends/cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
implementation and the standard OIDC implementation in Python Social Auth.
"""

from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

from .open_id_connect import OpenIdConnectAuth

if TYPE_CHECKING:
from collections.abc import Iterable


class CASOpenIdConnectAuth(OpenIdConnectAuth):
"""
Expand All @@ -30,7 +33,7 @@ class CASOpenIdConnectAuth(OpenIdConnectAuth):
STATE_PARAMETER = True

def oidc_endpoint(self):
endpoint = self.setting("OIDC_ENDPOINT", self.OIDC_ENDPOINT)
endpoint = super().oidc_endpoint()
self.log_debug("endpoint: %s", endpoint)
return endpoint

Expand Down Expand Up @@ -58,7 +61,7 @@ def get_user_details(self, response):
}

def auth_allowed(self, response, details):
allow_groups = set(self.setting("ALLOW_GROUPS", set()))
allow_groups = set(cast("Iterable", self.setting("ALLOW_GROUPS", set())))
groups = set(response.get("groups", set()))
return (
super().auth_allowed(response, details)
Expand Down
6 changes: 3 additions & 3 deletions social_core/backends/egi_checkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from __future__ import annotations

from typing import Literal
from typing import Literal, cast

from social_core.backends.open_id_connect import OpenIdConnectAuth

CHECKIN_ENV_ENDPOINTS = {
CHECKIN_ENV_ENDPOINTS: dict[str, str] = {
"prod": "https://aai.egi.eu/auth/realms/egi",
"demo": "https://aai-demo.egi.eu/auth/realms/egi",
"dev": "https://aai-dev.egi.eu/auth/realms/egi",
Expand Down Expand Up @@ -48,7 +48,7 @@ def oidc_endpoint(self):
endpoint = self.setting("OIDC_ENDPOINT", self.OIDC_ENDPOINT)
if endpoint:
return endpoint
checkin_env = self.setting("CHECKIN_ENV", self.CHECKIN_ENV)
checkin_env = cast("str", self.setting("CHECKIN_ENV", self.CHECKIN_ENV))
return CHECKIN_ENV_ENDPOINTS.get(checkin_env, "")

def get_user_details(self, response):
Expand Down
6 changes: 3 additions & 3 deletions social_core/backends/facebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import hmac
import json
import time
from typing import Any
from typing import Any, cast

from social_core.exceptions import (
AuthCanceled,
Expand Down Expand Up @@ -49,10 +49,10 @@ def auth_params(self, state=None):
return params

def get_authorization_url_format(self) -> dict[str, str]:
return {"version": self.setting("API_VERSION", API_VERSION)}
return {"version": cast("str", self.setting("API_VERSION", API_VERSION))}

def get_access_token_url_format(self) -> dict[str, str]:
return {"version": self.setting("API_VERSION", API_VERSION)}
return {"version": cast("str", self.setting("API_VERSION", API_VERSION))}

def get_user_details(self, response):
"""Return user details from Facebook account"""
Expand Down
4 changes: 2 additions & 2 deletions social_core/backends/keycloak.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, cast

import jwt

Expand Down Expand Up @@ -123,7 +123,7 @@ def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None
key=self.public_key(),
algorithms=self.algorithm(),
audience=self.audience(),
leeway=self.setting("JWT_LEEWAY", default=0),
leeway=cast("int", self.setting("JWT_LEEWAY", default=0)),
)

def get_user_details(self, response):
Expand Down
16 changes: 10 additions & 6 deletions social_core/backends/legacy.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from typing import cast

from social_core.exceptions import AuthMissingParameter

from .base import BaseAuth


class LegacyAuth(BaseAuth):
def auth_url(self):
return self.setting("FORM_URL")
def auth_url(self) -> str:
return cast("str", self.setting("FORM_URL"))

def auth_html(self):
return self.strategy.render_html(tpl=self.setting("FORM_HTML"))
def auth_html(self) -> str:
return self.strategy.render_html(tpl=cast("str", self.setting("FORM_HTML")))

def uses_redirect(self):
return self.setting("FORM_URL") and not self.setting("FORM_HTML")
def uses_redirect(self) -> bool:
return bool(cast("str", self.setting("FORM_URL"))) and not cast(
"str", self.setting("FORM_HTML")
)

def auth_complete(self, *args, **kwargs):
"""Completes login process, must return user instance"""
Expand Down
13 changes: 10 additions & 3 deletions social_core/backends/linkedin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import datetime
from calendar import timegm
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast

from social_core.backends.open_id_connect import OpenIdConnectAuth
from social_core.exceptions import AuthCanceled, AuthTokenError
Expand Down Expand Up @@ -75,7 +75,12 @@ class LinkedinOAuth2(BaseOAuth2):
def user_details_url(self):
# use set() since LinkedIn fails when values are duplicated
fields_selectors = list(
{"id", "firstName", "lastName", *self.setting("FIELD_SELECTORS", [])}
{
"id",
"firstName",
"lastName",
*cast("list[str]", self.setting("FIELD_SELECTORS", [])),
}
)
# user sort to ease the tests URL mocking
fields_selectors.sort()
Expand All @@ -90,7 +95,9 @@ def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None
self.user_details_url(), headers=self.user_data_headers(access_token)
)

if "emailAddress" in set(self.setting("FIELD_SELECTORS", [])):
if "emailAddress" in set(
cast("list[str]", self.setting("FIELD_SELECTORS", []))
):
emails = self.email_data(access_token, *args, **kwargs)
if emails:
response["emailAddress"] = emails[0]
Expand Down
7 changes: 4 additions & 3 deletions social_core/backends/mediawiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import re
import time
from typing import cast
from urllib.parse import parse_qs, urlencode, urlparse

import jwt
Expand Down Expand Up @@ -46,7 +47,7 @@ def unauthorized_token(self):
params["title"] = "Special:OAuth/initiate"
key, secret = self.get_key_and_secret()
response = self.request(
self.setting("MEDIAWIKI_URL"),
cast("str", self.setting("MEDIAWIKI_URL")),
params=params,
auth=OAuth1(key, secret, callback_uri=self.setting("CALLBACK")),
method=self.REQUEST_TOKEN_METHOD,
Expand Down Expand Up @@ -83,7 +84,7 @@ def access_token(self, token):
auth_token = self.oauth_auth(token)

response = self.request(
self.setting("MEDIAWIKI_URL"),
cast("str", self.setting("MEDIAWIKI_URL")),
method="POST",
params={"title": "Special:Oauth/token"},
auth=auth_token,
Expand Down Expand Up @@ -116,7 +117,7 @@ def get_user_details(self, response):
)

req_resp = self.request(
self.setting("MEDIAWIKI_URL"),
cast("str", self.setting("MEDIAWIKI_URL")),
method="POST",
params={"title": "Special:OAuth/identify"},
auth=auth,
Expand Down
6 changes: 3 additions & 3 deletions social_core/backends/mineid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, cast

from .oauth import BaseOAuth2

Expand Down Expand Up @@ -31,6 +31,6 @@ def get_access_token_url_format(self) -> dict[str, str]:

def get_mineid_url_params(self) -> dict[str, str]:
return {
"host": self.setting("HOST", "www.mineid.org"),
"scheme": self.setting("SCHEME", "https"),
"host": cast("str", self.setting("HOST", "www.mineid.org")),
"scheme": cast("str", self.setting("SCHEME", "https")),
}
6 changes: 3 additions & 3 deletions social_core/backends/nationbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
https://python-social-auth.readthedocs.io/en/latest/backends/nationbuilder.html
"""

from typing import Any
from typing import Any, cast

from .oauth import BaseOAuth2

Expand All @@ -25,8 +25,8 @@ def get_access_token_url_format(self) -> dict[str, str]:
return {"slug": self.slug}

@property
def slug(self):
return self.setting("SLUG")
def slug(self) -> str:
return cast("str", self.setting("SLUG"))

def get_user_details(self, response):
"""Return user details from Github account"""
Expand Down
2 changes: 1 addition & 1 deletion social_core/backends/ngpvan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ActionIDOpenID(OpenIdAuth):
URL = "https://accounts.ngpvan.com/Home/Xrds"
USERNAME_KEY = "email"

def get_ax_attributes(self):
def get_ax_attributes(self) -> list[tuple[str, str]]:
"""
Return the AX attributes that ActionID responds with, as well as the
user data result that it must map to.
Expand Down
31 changes: 19 additions & 12 deletions social_core/backends/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ def get_redirect_uri(self, state: str | None = None) -> str:
uri = url_add_parameters(uri, {"redirect_state": state})
return uri

def get_scope(self):
def get_scope(self) -> list[str]:
"""Return list with needed access scope"""
scope = self.setting("SCOPE", [])
scope = cast("list[str]", self.setting("SCOPE", []))
if not self.setting("IGNORE_DEFAULT_SCOPE", False):
scope = scope + (self.DEFAULT_SCOPE or [])
return scope
Expand All @@ -147,7 +147,7 @@ def user_data(self, access_token, *args, **kwargs) -> dict[str, Any] | None:
raise NotImplementedError

def authorization_url(self) -> str:
url = self.setting("AUTHORIZATION_URL", self.AUTHORIZATION_URL)
url = cast("str", self.setting("AUTHORIZATION_URL", self.AUTHORIZATION_URL))
if format_params := self.get_authorization_url_format():
return url.format(**format_params)
return url
Expand All @@ -156,7 +156,7 @@ def get_authorization_url_format(self) -> dict[str, str]:
return {}

def access_token_url(self) -> str:
url = self.setting("ACCESS_TOKEN_URL", self.ACCESS_TOKEN_URL)
url = cast("str", self.setting("ACCESS_TOKEN_URL", self.ACCESS_TOKEN_URL))
if format_params := self.get_access_token_url_format():
return url.format(**format_params)
return url
Expand All @@ -165,7 +165,7 @@ def get_access_token_url_format(self) -> dict[str, str]:
return {}

def revoke_token_url(self, token, uid) -> str:
return self.setting("REVOKE_TOKEN_URL", self.REVOKE_TOKEN_URL)
return cast("str", self.setting("REVOKE_TOKEN_URL", self.REVOKE_TOKEN_URL))

def revoke_token_params(self, token, uid) -> dict[str, Any]:
return {}
Expand Down Expand Up @@ -272,9 +272,9 @@ def set_unauthorized_token(self):
self.strategy.session_set(name, tokens)
return token

def request_token_extra_arguments(self):
def request_token_extra_arguments(self) -> dict[str, str]:
"""Return extra arguments needed on request-token process"""
return self.setting("REQUEST_TOKEN_EXTRA_ARGUMENTS", {})
return cast("dict[str, str]", self.setting("REQUEST_TOKEN_EXTRA_ARGUMENTS", {}))

def unauthorized_token(self):
"""Return request for unauthorized token (first stage)"""
Expand Down Expand Up @@ -564,8 +564,12 @@ class BaseOAuth2PKCE(BaseOAuth2):

def create_code_verifier(self):
name = f"{self.name}_code_verifier"
code_verifier_len = self.setting(
"PKCE_CODE_VERIFIER_LENGTH", default=self.PKCE_DEFAULT_CODE_VERIFIER_LENGTH
code_verifier_len = cast(
"int",
self.setting(
"PKCE_CODE_VERIFIER_LENGTH",
default=self.PKCE_DEFAULT_CODE_VERIFIER_LENGTH,
),
)
code_verifier = self.strategy.random_string(code_verifier_len)
self.strategy.session_set(name, code_verifier)
Expand All @@ -589,9 +593,12 @@ def auth_params(self, state=None):
params = super().auth_params(state=state)

if self.setting("USE_PKCE", default=self.DEFAULT_USE_PKCE):
code_challenge_method = self.setting(
"PKCE_CODE_CHALLENGE_METHOD",
default=self.PKCE_DEFAULT_CODE_CHALLENGE_METHOD,
code_challenge_method = cast(
"str",
self.setting(
"PKCE_CODE_CHALLENGE_METHOD",
default=self.PKCE_DEFAULT_CODE_CHALLENGE_METHOD,
),
)
code_verifier = self.create_code_verifier()
code_challenge = self.generate_code_challenge(
Expand Down
23 changes: 13 additions & 10 deletions social_core/backends/odnoklassniki.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from hashlib import md5
from typing import Any
from typing import Any, cast
from urllib.parse import unquote

from social_core.exceptions import AuthFailed
Expand Down Expand Up @@ -102,7 +102,7 @@ def auth_complete(self, *args, **kwargs):
"first_name",
"last_name",
"name",
*self.setting("EXTRA_USER_DATA_LIST", ()),
*cast("tuple[str, ...]", self.setting("EXTRA_USER_DATA_LIST", ())),
)
data = {
"method": "users.getInfo",
Expand All @@ -121,14 +121,17 @@ def auth_complete(self, *args, **kwargs):
)
if len(details) == 1 and "uid" in details[0]:
details = details[0]
auth_data_fields = self.setting(
"EXTRA_AUTH_DATA_LIST",
(
"api_server",
"apiconnection",
"session_key",
"authorized",
"session_secret_key",
auth_data_fields = cast(
"tuple[str, ...]",
self.setting(
"EXTRA_AUTH_DATA_LIST",
(
"api_server",
"apiconnection",
"session_key",
"authorized",
"session_secret_key",
),
),
)

Expand Down
Loading
Loading