1616 from requests .auth import AuthBase
1717
1818 from social_core .storage import UserProtocol
19+ from social_core .strategy import HttpResponseProtocol
1920
2021
2122class BaseAuth :
@@ -46,12 +47,12 @@ def setting(self, name: str, default=None):
4647 """Return setting value from strategy"""
4748 return self .strategy .setting (name , default = default , backend = self )
4849
49- def start (self ):
50+ def start (self ) -> HttpResponseProtocol :
5051 if self .uses_redirect ():
5152 return self .strategy .redirect (self .auth_url ())
5253 return self .strategy .html (self .auth_html ())
5354
54- def complete (self , * args , ** kwargs ):
55+ def complete (self , * args , ** kwargs ) -> UserProtocol | None :
5556 return self .auth_complete (* args , ** kwargs )
5657
5758 def auth_url (self ) -> str :
@@ -62,15 +63,17 @@ def auth_html(self) -> str:
6263 """Must return login HTML content returned by provider"""
6364 return "Implement in subclass"
6465
65- def auth_complete (self , * args , ** kwargs ):
66+ def auth_complete (self , * args , ** kwargs ) -> UserProtocol | None :
6667 """Completes login process, must return user instance"""
6768 raise NotImplementedError ("Implement in subclass" )
6869
6970 def process_error (self , data ) -> None :
7071 """Process data for errors, raise exception if needed.
7172 Call this method on any override of auth_complete."""
7273
73- def authenticate (self , * args , ** kwargs ):
74+ def authenticate (
75+ self , * args , ** kwargs
76+ ) -> UserProtocol | HttpResponseProtocol | None :
7477 """Authenticate user using social credentials
7578
7679 Authentication is made if this is the correct backend, backend
@@ -97,23 +100,27 @@ def authenticate(self, *args, **kwargs):
97100 args , kwargs = self .strategy .clean_authenticate_args (* args , ** kwargs )
98101 return self .pipeline (pipeline , * args , ** kwargs )
99102
100- def pipeline (self , pipeline , pipeline_index = 0 , * args , ** kwargs ):
103+ def pipeline (
104+ self , pipeline , pipeline_index : int = 0 , * args , ** kwargs
105+ ) -> UserProtocol | HttpResponseProtocol | None :
101106 out = self .run_pipeline (pipeline , pipeline_index , * args , ** kwargs )
102107 if not isinstance (out , dict ):
103- return out
104- user = out .get ("user" )
108+ return cast ( "HttpResponseProtocol" , out )
109+ user = cast ( "UserProtocol | None" , out .get ("user" ) )
105110 if user :
106- user .social_user = out .get ("social" )
107- user .is_new = out .get ("is_new" )
111+ user .social_user = out .get ("social" ) # type: ignore[attr-defined]
112+ user .is_new = out .get ("is_new" ) # type: ignore[attr-defined]
108113 return user
109114
110- def disconnect (self , * args , ** kwargs ):
115+ def disconnect (self , * args , ** kwargs ) -> dict :
111116 pipeline = self .strategy .get_disconnect_pipeline (self )
112117 kwargs ["name" ] = self .name
113118 kwargs ["user_storage" ] = self .strategy .storage .user
114119 return self .run_pipeline (pipeline , * args , ** kwargs )
115120
116- def run_pipeline (self , pipeline : list [str ], pipeline_index = 0 , * args , ** kwargs ):
121+ def run_pipeline (
122+ self , pipeline : list [str ], pipeline_index = 0 , * args , ** kwargs
123+ ) -> dict :
117124 out = kwargs .copy ()
118125 out .setdefault ("strategy" , self .strategy )
119126 out .setdefault ("backend" , out .pop (self .name , None ) or self )
0 commit comments