1
+ import base64
2
+ import hashlib
3
+
1
4
from oauthlib .oauth2 import Client
2
5
from oauthlib .oauth2 import RequestValidator
3
6
4
7
5
8
class TestValidator (RequestValidator ):
9
+ pkce_codes = {}
10
+
6
11
def validate_client_id (self , client_id , request , * args , ** kwargs ):
7
12
return True
8
13
@@ -23,12 +28,30 @@ def confirm_redirect_uri(self, client_id, code, redirect_uri, client, request, *
23
28
return True
24
29
25
30
def validate_code (self , client_id , code , client , request , * args , ** kwargs ):
26
- return True
31
+ stored_challenge = self .pkce_codes .get (code )
32
+ if not stored_challenge :
33
+ return False
34
+
35
+ code_verifier = request .code_verifier
36
+ code_challenge = stored_challenge .get ("code_challenge" )
37
+ code_challenge_method = stored_challenge .get ("code_challenge_method" )
38
+
39
+ computed_challenge = code_verifier
40
+ if code_challenge_method == "S256" :
41
+ sha256 = hashlib .sha256 ()
42
+ sha256 .update (code_verifier .encode ("utf-8" ))
43
+ computed_challenge = base64 .urlsafe_b64encode (sha256 .digest ()).decode ("utf-8" ).replace ("=" , "" )
44
+
45
+ return computed_challenge == code_challenge
27
46
28
47
def validate_scopes (self , client_id , scopes , client , request , * args , ** kwargs ):
29
48
return True
30
49
31
50
def save_authorization_code (self , client_id , code , request , * args , ** kwargs ):
51
+ self .pkce_codes [code .get ("code" )] = dict (
52
+ code_challenge = request .code_challenge ,
53
+ code_challenge_method = request .code_challenge_method ,
54
+ )
32
55
return True
33
56
34
57
def validate_response_type (self , client_id , response_type , client , request , * args , ** kwargs ):
0 commit comments