Skip to content

Commit 24064d1

Browse files
committed
Fix : CSRF error
1 parent ba2e5d9 commit 24064d1

File tree

3 files changed

+82
-48
lines changed

3 files changed

+82
-48
lines changed

Tests/test_csrf_json.py

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,51 @@
1616

1717
async def run_csrf_test():
1818
"""
19-
Tests that CSRF protection works correctly for JSON endpoints.
19+
Tests that CSRF protection works correctly for various request types.
2020
"""
21-
print("--- Starting CSRF JSON Test ---")
21+
print("--- Starting CSRF Logic Test ---")
2222
async with httpx.AsyncClient(base_url=BASE_URL) as client:
2323
try:
24-
# 1. Make a GET request to a page to get a CSRF token
24+
# 1. Make a GET request to a page to get a CSRF token from the cookie
2525
print("Step 1: Getting CSRF token from homepage...")
2626
get_response = await client.get("/")
2727
get_response.raise_for_status()
2828
assert "csrf_token" in client.cookies, "CSRF token not found in cookie"
2929
csrf_token = client.cookies["csrf_token"]
30-
print(" [PASS] CSRF token received.")
30+
print(f" [PASS] CSRF token received: {csrf_token[:10]}...")
3131

32-
# 2. Send a POST request to the JSON endpoint WITHOUT a CSRF token
32+
# 2. Test POST without any CSRF token (should fail)
3333
print("\nStep 2: Testing POST to /api/test without CSRF token (expecting 403)...")
34-
payload = {"message": "hello"}
35-
fail_response = await client.post("/api/test", json=payload)
34+
fail_response = await client.post("/api/test", json={"message": "hello"})
3635
assert fail_response.status_code == 403, f"Expected status 403, but got {fail_response.status_code}"
3736
assert "CSRF token missing or invalid" in fail_response.text
3837
print(" [PASS] Request was correctly forbidden.")
3938

40-
# 3. Send a POST request to the JSON endpoint WITH the correct CSRF token
41-
print("\nStep 3: Testing POST to /api/test with CSRF token (expecting 200)...")
39+
# 3. Test POST with CSRF token in JSON body (should pass)
40+
print("\nStep 3: Testing POST to /api/test with CSRF token in JSON body (expecting 200)...")
4241
payload_with_token = {"message": "hello", "csrf_token": csrf_token}
43-
success_response = await client.post("/api/test", json=payload_with_token)
44-
assert success_response.status_code == 200, f"Expected status 200, but got {success_response.status_code}"
45-
response_json = success_response.json()
46-
assert response_json["message"] == "hello"
47-
print(" [PASS] Request was successful.")
42+
success_response_body = await client.post("/api/test", json=payload_with_token)
43+
assert success_response_body.status_code == 200, f"Expected status 200, but got {success_response_body.status_code}"
44+
assert success_response_body.json()["message"] == "hello"
45+
print(" [PASS] Request with token in body was successful.")
46+
47+
# 4. Test POST with CSRF token in header (should pass)
48+
print("\nStep 4: Testing POST to /api/test with CSRF token in header (expecting 200)...")
49+
headers = {"X-CSRF-Token": csrf_token}
50+
success_response_header = await client.post("/api/test", json={"message": "world"}, headers=headers)
51+
assert success_response_header.status_code == 200, f"Expected status 200, but got {success_response_header.status_code}"
52+
assert success_response_header.json()["message"] == "world"
53+
print(" [PASS] Request with token in header was successful.")
54+
55+
# 5. Test empty-body POST with CSRF token in header (should pass validation, then redirect)
56+
print("\nStep 5: Testing empty-body POST to /logout with CSRF token in header (expecting 302)...")
57+
# Note: The /logout endpoint redirects after success, so we expect a 302
58+
# We disable auto-redirects to verify the 302 status directly
59+
empty_body_response = await client.post("/logout", headers=headers, follow_redirects=False)
60+
61+
# If we got a 403, the CSRF check failed. If we got a 302, it passed!
62+
assert empty_body_response.status_code == 302, f"Expected status 302 (Redirect), but got {empty_body_response.status_code}. (403 means CSRF failed)"
63+
print(" [PASS] Empty-body request passed CSRF check and redirected.")
4864

4965
except Exception as e:
5066
print(f"\n--- TEST FAILED ---")
@@ -53,7 +69,7 @@ async def run_csrf_test():
5369
traceback.print_exc()
5470
return False
5571

56-
print("\n--- TEST PASSED ---")
72+
print("\n--- ALL CSRF TESTS PASSED ---")
5773
return True
5874

5975

@@ -89,13 +105,15 @@ def main():
89105
print("\nStopping test server...")
90106
server_process.terminate()
91107
# Get remaining output
92-
stdout, stderr = server_process.communicate(timeout=5)
93-
94-
print("\n--- Server Output ---")
95-
print("STDOUT:")
96-
print(stdout)
97-
print("\nSTDERR:")
98-
print(stderr)
108+
try:
109+
stdout, stderr = server_process.communicate(timeout=5)
110+
print("\n--- Server Output ---")
111+
print("STDOUT:")
112+
print(stdout)
113+
print("\nSTDERR:")
114+
print(stderr)
115+
except subprocess.TimeoutExpired:
116+
print("Server did not terminate gracefully.")
99117

100118
if not test_passed:
101119
print("\nExiting with status 1 due to test failure.")

jsweb/middleware.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from .static import serve_static
44
from .response import Forbidden
5+
import json
56

67
logger = logging.getLogger(__name__)
78

@@ -30,10 +31,11 @@ class CSRFMiddleware(Middleware):
3031
"""
3132
Middleware to protect against Cross-Site Request Forgery (CSRF) attacks.
3233
33-
This middleware checks for a valid CSRF token in POST, PUT, PATCH, and DELETE
34-
requests. It supports both form-based and JSON-based requests.
34+
This middleware enforces CSRF protection for all state-changing HTTP methods
35+
(POST, PUT, PATCH, DELETE). It requires a valid CSRF token to be present
36+
in the request, either in the 'X-CSRF-Token' header or in the request body
37+
(JSON or Form Data).
3538
"""
36-
3739
async def __call__(self, scope, receive, send):
3840
"""
3941
Validates the CSRF token for state-changing HTTP methods.
@@ -53,34 +55,48 @@ async def __call__(self, scope, receive, send):
5355
req = scope['jsweb.request']
5456

5557
if req.method in ("POST", "PUT", "PATCH", "DELETE"):
56-
token = await self._get_token_from_request(req)
5758
cookie_token = req.cookies.get("csrf_token")
58-
59-
if not token or not cookie_token or not secrets.compare_digest(token, cookie_token):
60-
logger.error("CSRF VALIDATION FAILED. Tokens do not match or are missing.")
59+
submitted_token = None
60+
61+
# 1. Check header first (Best practice for AJAX/APIs)
62+
submitted_token = req.headers.get("x-csrf-token")
63+
64+
# 2. If no header token, check the body based on content type
65+
if not submitted_token:
66+
content_type = req.headers.get("content-type", "")
67+
68+
if "application/json" in content_type:
69+
try:
70+
# Request.json() safely returns {} for empty/invalid bodies
71+
data = await req.json()
72+
submitted_token = data.get("csrf_token")
73+
except Exception:
74+
# If JSON parsing fails, we treat it as no token found
75+
pass
76+
77+
elif "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
78+
try:
79+
# Request.form() safely returns {} for empty/non-form bodies
80+
form = await req.form()
81+
submitted_token = form.get("csrf_token")
82+
except Exception:
83+
# If form parsing fails, we treat it as no token found
84+
pass
85+
86+
# 3. Perform the validation
87+
# Both the cookie token and the submitted token MUST be present and match.
88+
if not cookie_token or not submitted_token or not secrets.compare_digest(submitted_token, cookie_token):
89+
logger.warning(
90+
f"CSRF validation failed for {req.method} {req.path}. "
91+
f"Cookie set: {'Yes' if cookie_token else 'No'}, "
92+
f"Token submitted: {'Yes' if submitted_token else 'No'}."
93+
)
6194
response = Forbidden("CSRF token missing or invalid.")
6295
await response(scope, receive, send)
6396
return
6497

6598
await self.app(scope, receive, send)
6699

67-
async def _get_token_from_request(self, req):
68-
"""
69-
Extracts the CSRF token from the request, handling both JSON and form data.
70-
"""
71-
content_type = req.headers.get("content-type", "")
72-
if "application/json" in content_type:
73-
try:
74-
data = await req.json()
75-
return data.get("csrf_token")
76-
except Exception:
77-
# In case of malformed JSON, treat as if no token was sent
78-
return None
79-
else:
80-
# Fallback for form data
81-
form = await req.form()
82-
return form.get("csrf_token")
83-
84100
class StaticFilesMiddleware(Middleware):
85101
"""
86102
Middleware for serving static files.

jsweb/request.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(self, scope, receive, app):
3434
self.path = self.scope.get("path", "/")
3535
self.query = self._parse_query(self.scope.get("query_string", b"").decode())
3636
self.headers = self._parse_headers(self.scope.get("headers", []))
37+
self.content_type = self.headers.get("content-type", "")
3738
self.cookies = self._parse_cookies(self.headers)
3839
self.user = None
3940

@@ -101,8 +102,7 @@ async def json(self):
101102
dict: The parsed JSON data.
102103
"""
103104
if self._json is None:
104-
content_type = self.headers.get("content-type", "")
105-
if "application/json" in content_type:
105+
if "application/json" in self.content_type:
106106
try:
107107
body_bytes = await self.body()
108108
self._json = json.loads(body_bytes) if body_bytes else {}

0 commit comments

Comments
 (0)