@@ -14,6 +14,34 @@ def setUp(self):
1414 userinfo_url = "https://example.com/userinfo" # Added userinfo_url
1515 )
1616
17+ def test_register (self ):
18+ # Test the register method
19+ url = self .oauth .register (state = "xyz" , scope = ["openid" , "profile" , "email" ])
20+ parsed_url = urlparse (url )
21+ query_params = parse_qs (parsed_url .query )
22+
23+ self .assertEqual (parsed_url .scheme , "https" )
24+ self .assertEqual (parsed_url .netloc , "example.com" )
25+ self .assertEqual (parsed_url .path , "/auth/register" )
26+ self .assertEqual (query_params ["client_id" ][0 ], "test_client_id" )
27+ self .assertEqual (query_params ["redirect_uri" ][0 ], "http://localhost/callback" )
28+ self .assertEqual (query_params ["scope" ][0 ], "openid profile email" ) # Expected scope
29+ self .assertEqual (query_params ["state" ][0 ], "xyz" )
30+
31+ def test_login (self ):
32+ # Test the login method
33+ url = self .oauth .login (state = "xyz" , scope = ["openid" , "profile" , "email" ])
34+ parsed_url = urlparse (url )
35+ query_params = parse_qs (parsed_url .query )
36+
37+ self .assertEqual (parsed_url .scheme , "https" )
38+ self .assertEqual (parsed_url .netloc , "example.com" )
39+ self .assertEqual (parsed_url .path , "/auth/login" )
40+ self .assertEqual (query_params ["client_id" ][0 ], "test_client_id" )
41+ self .assertEqual (query_params ["redirect_uri" ][0 ], "http://localhost/callback" )
42+ self .assertEqual (query_params ["scope" ][0 ], "openid profile email" ) # Expected scope
43+ self .assertEqual (query_params ["state" ][0 ], "xyz" )
44+
1745 def test_get_login_url (self ):
1846 # Explicitly specify the scope
1947 url = self .oauth .get_login_url (state = "xyz" , scope = ["openid" , "profile" , "email" ])
@@ -57,7 +85,8 @@ def test_get_tokens_for_core(self):
5785 self .assertEqual (tokens ["refresh_token" ], "test_refresh_token" )
5886
5987 def test_logout (self ):
60- logout_url = self .oauth .logout ("test_user_id" )
88+ # Test the logout method
89+ logout_url = self .oauth .logout (state = "xyz" )
6190 parsed_url = urlparse (logout_url )
6291 query_params = parse_qs (parsed_url .query )
6392
@@ -66,6 +95,7 @@ def test_logout(self):
6695 self .assertEqual (parsed_url .path , "/logout" )
6796 self .assertEqual (query_params ["client_id" ][0 ], "test_client_id" )
6897 self .assertEqual (query_params ["logout_uri" ][0 ], "http://localhost/callback" )
98+ self .assertEqual (query_params ["state" ][0 ], "xyz" )
6999
70- if __name__ == "__main__ " :
71- unittest .main ()
100+ # if __name__ == "_main_ ":
101+ # unittest.main()
0 commit comments