@@ -40,25 +40,19 @@ def test_get_header(self):
4040 )
4141 self .assertEqual (self .backend .get_header (request ), self .fake_header )
4242
43- # Should work with the x_access_token
44- with override_api_settings (AUTH_HEADER_NAME = "HTTP_X_ACCESS_TOKEN" ):
45- # Should pull correct header off request when using X_ACCESS_TOKEN
46- request = self .factory .get (
47- "/test-url/" , HTTP_X_ACCESS_TOKEN = self .fake_header
48- )
49- self .assertEqual (self .backend .get_header (request ), self .fake_header )
50-
51- # Should work for unicode headers when using
52- request = self .factory .get (
53- "/test-url/" , HTTP_X_ACCESS_TOKEN = self .fake_header .decode ("utf-8" )
54- )
55- self .assertEqual (self .backend .get_header (request ), self .fake_header )
43+ @override_api_settings (AUTH_HEADER_NAME = "HTTP_X_ACCESS_TOKEN" )
44+ def test_get_header_x_access_token (self ):
45+ # Should pull correct header off request when using X_ACCESS_TOKEN
46+ request = self .factory .get ("/test-url/" , HTTP_X_ACCESS_TOKEN = self .fake_header )
47+ self .assertEqual (self .backend .get_header (request ), self .fake_header )
48+
49+ # Should work for unicode headers when using
50+ request = self .factory .get (
51+ "/test-url/" , HTTP_X_ACCESS_TOKEN = self .fake_header .decode ("utf-8" )
52+ )
53+ self .assertEqual (self .backend .get_header (request ), self .fake_header )
5654
5755 def test_get_raw_token (self ):
58- # Should return None if header lacks correct type keyword
59- with override_api_settings (AUTH_HEADER_TYPES = "JWT" ):
60- reload (authentication )
61- self .assertIsNone (self .backend .get_raw_token (self .fake_header ))
6256 reload (authentication )
6357
6458 # Should return None if an empty AUTHORIZATION header is sent
@@ -74,14 +68,21 @@ def test_get_raw_token(self):
7468 # Otherwise, should return unvalidated token in header
7569 self .assertEqual (self .backend .get_raw_token (self .fake_header ), self .fake_token )
7670
71+ @override_api_settings (AUTH_HEADER_TYPES = "JWT" )
72+ def test_get_raw_token_incorrect_header_keyword (self ):
73+ # Should return None if header lacks correct type keyword
74+ # AUTH_HEADER_TYPES is "JWT", but header is "Bearer"
75+ reload (authentication )
76+ self .assertIsNone (self .backend .get_raw_token (self .fake_header ))
77+
78+ @override_api_settings (AUTH_HEADER_TYPES = ("JWT" , "Bearer" ))
79+ def test_get_raw_token_multi_header_keyword (self ):
7780 # Should return token if header has one of many valid token types
78- with override_api_settings (AUTH_HEADER_TYPES = ("JWT" , "Bearer" )):
79- reload (authentication )
80- self .assertEqual (
81- self .backend .get_raw_token (self .fake_header ),
82- self .fake_token ,
83- )
8481 reload (authentication )
82+ self .assertEqual (
83+ self .backend .get_raw_token (self .fake_header ),
84+ self .fake_token ,
85+ )
8586
8687 def test_get_validated_token (self ):
8788 # Should raise InvalidToken if token not valid
@@ -96,36 +97,39 @@ def test_get_validated_token(self):
9697 self .backend .get_validated_token (str (token )).payload , token .payload
9798 )
9899
100+ @override_api_settings (
101+ AUTH_TOKEN_CLASSES = ("rest_framework_simplejwt.tokens.AccessToken" ,),
102+ )
103+ def test_get_validated_token_reject_unknown_token (self ):
99104 # Should not accept tokens not included in AUTH_TOKEN_CLASSES
100105 sliding_token = SlidingToken ()
101- with override_api_settings (
102- AUTH_TOKEN_CLASSES = ("rest_framework_simplejwt.tokens.AccessToken" ,)
103- ):
104- with self .assertRaises (InvalidToken ) as e :
105- self .backend .get_validated_token (str (sliding_token ))
106-
107- messages = e .exception .detail ["messages" ]
108- self .assertEqual (1 , len (messages ))
109- self .assertEqual (
110- {
111- "token_class" : "AccessToken" ,
112- "token_type" : "access" ,
113- "message" : "Token has wrong type" ,
114- },
115- messages [0 ],
116- )
106+ with self .assertRaises (InvalidToken ) as e :
107+ self .backend .get_validated_token (str (sliding_token ))
108+
109+ messages = e .exception .detail ["messages" ]
110+ self .assertEqual (1 , len (messages ))
111+ self .assertEqual (
112+ {
113+ "token_class" : "AccessToken" ,
114+ "token_type" : "access" ,
115+ "message" : "Token has wrong type" ,
116+ },
117+ messages [0 ],
118+ )
117119
120+ @override_api_settings (
121+ AUTH_TOKEN_CLASSES = (
122+ "rest_framework_simplejwt.tokens.AccessToken" ,
123+ "rest_framework_simplejwt.tokens.SlidingToken" ,
124+ ),
125+ )
126+ def test_get_validated_token_accept_known_token (self ):
118127 # Should accept tokens included in AUTH_TOKEN_CLASSES
119128 access_token = AccessToken ()
120129 sliding_token = SlidingToken ()
121- with override_api_settings (
122- AUTH_TOKEN_CLASSES = (
123- "rest_framework_simplejwt.tokens.AccessToken" ,
124- "rest_framework_simplejwt.tokens.SlidingToken" ,
125- )
126- ):
127- self .backend .get_validated_token (str (access_token ))
128- self .backend .get_validated_token (str (sliding_token ))
130+
131+ self .backend .get_validated_token (str (access_token ))
132+ self .backend .get_validated_token (str (sliding_token ))
129133
130134 def test_get_user (self ):
131135 payload = {"some_other_id" : "foo" }
0 commit comments