@@ -26,6 +26,23 @@ def dict_to_bytes(d: dict) -> bytes:
26
26
return json .dumps (d , separators = ("," , ":" )).encode ("utf-8" )
27
27
28
28
29
+ def _check_endpoint_match (
30
+ path : str ,
31
+ method : str ,
32
+ endpoints : EndpointMethods ,
33
+ ) -> tuple [bool , Sequence [str ]]:
34
+ """Check if the path and method match any endpoint in the given endpoints map."""
35
+ for pattern , endpoint_methods in endpoints .items ():
36
+ if re .match (pattern , path ):
37
+ for endpoint_method in endpoint_methods :
38
+ required_scopes : Sequence [str ] = []
39
+ if isinstance (endpoint_method , tuple ):
40
+ endpoint_method , required_scopes = endpoint_method
41
+ if method .casefold () == endpoint_method .casefold ():
42
+ return True , required_scopes
43
+ return False , []
44
+
45
+
29
46
def find_match (
30
47
path : str ,
31
48
method : str ,
@@ -34,22 +51,25 @@ def find_match(
34
51
default_public : bool ,
35
52
) -> "MatchResult" :
36
53
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
37
- endpoints = private_endpoints if default_public else public_endpoints
38
- for pattern , endpoint_methods in endpoints .items ():
39
- if not re .match (pattern , path ):
40
- continue
41
- for endpoint_method in endpoint_methods :
42
- required_scopes : Sequence [str ] = []
43
- if isinstance (endpoint_method , tuple ):
44
- endpoint_method , required_scopes = endpoint_method
45
- if method .casefold () == endpoint_method .casefold ():
46
- # If default_public, we're looking for a private endpoint.
47
- # If not default_public, we're looking for a public endpoint.
48
- return MatchResult (
49
- is_private = default_public ,
50
- required_scopes = required_scopes ,
51
- )
52
- return MatchResult (is_private = not default_public )
54
+ primary_endpoints = private_endpoints if default_public else public_endpoints
55
+ matched , required_scopes = _check_endpoint_match (path , method , primary_endpoints )
56
+ if matched :
57
+ return MatchResult (
58
+ is_private = default_public ,
59
+ required_scopes = required_scopes ,
60
+ )
61
+
62
+ # If default_public and no match found in private_endpoints, it's public
63
+ if default_public :
64
+ return MatchResult (is_private = False )
65
+
66
+ # If not default_public, check private_endpoints for required scopes
67
+ matched , required_scopes = _check_endpoint_match (path , method , private_endpoints )
68
+ if matched :
69
+ return MatchResult (is_private = True , required_scopes = required_scopes )
70
+
71
+ # Default case: if not default_public and no explicit match, it's private
72
+ return MatchResult (is_private = True )
53
73
54
74
55
75
@dataclass
0 commit comments