Skip to content

Commit 6de9f41

Browse files
authored
Merge pull request #20 from Jsweb-Tech/csrf_bug_fix
Bug Fix : Issue no #18 CSRF Protection Bypass for JSON APIs fixed
2 parents a9b1db4 + 24064d1 commit 6de9f41

File tree

3 files changed

+166
-9
lines changed

3 files changed

+166
-9
lines changed

Tests/test_csrf_json.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import asyncio
2+
import httpx
3+
import subprocess
4+
import sys
5+
import time
6+
import os
7+
8+
# Construct absolute path to the test application directory
9+
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
10+
TEST_APP_DIR = os.path.join(TESTS_DIR, "test")
11+
12+
# Ensure the test application is in the python path
13+
sys.path.insert(0, TEST_APP_DIR)
14+
15+
BASE_URL = "http://127.0.0.1:8000"
16+
17+
async def run_csrf_test():
18+
"""
19+
Tests that CSRF protection works correctly for various request types.
20+
"""
21+
print("--- Starting CSRF Logic Test ---")
22+
async with httpx.AsyncClient(base_url=BASE_URL) as client:
23+
try:
24+
# 1. Make a GET request to a page to get a CSRF token from the cookie
25+
print("Step 1: Getting CSRF token from homepage...")
26+
get_response = await client.get("/")
27+
get_response.raise_for_status()
28+
assert "csrf_token" in client.cookies, "CSRF token not found in cookie"
29+
csrf_token = client.cookies["csrf_token"]
30+
print(f" [PASS] CSRF token received: {csrf_token[:10]}...")
31+
32+
# 2. Test POST without any CSRF token (should fail)
33+
print("\nStep 2: Testing POST to /api/test without CSRF token (expecting 403)...")
34+
fail_response = await client.post("/api/test", json={"message": "hello"})
35+
assert fail_response.status_code == 403, f"Expected status 403, but got {fail_response.status_code}"
36+
assert "CSRF token missing or invalid" in fail_response.text
37+
print(" [PASS] Request was correctly forbidden.")
38+
39+
# 3. Test POST with CSRF token in JSON body (should pass)
40+
print("\nStep 3: Testing POST to /api/test with CSRF token in JSON body (expecting 200)...")
41+
payload_with_token = {"message": "hello", "csrf_token": csrf_token}
42+
success_response_body = await client.post("/api/test", json=payload_with_token)
43+
assert success_response_body.status_code == 200, f"Expected status 200, but got {success_response_body.status_code}"
44+
assert success_response_body.json()["message"] == "hello"
45+
print(" [PASS] Request with token in body was successful.")
46+
47+
# 4. Test POST with CSRF token in header (should pass)
48+
print("\nStep 4: Testing POST to /api/test with CSRF token in header (expecting 200)...")
49+
headers = {"X-CSRF-Token": csrf_token}
50+
success_response_header = await client.post("/api/test", json={"message": "world"}, headers=headers)
51+
assert success_response_header.status_code == 200, f"Expected status 200, but got {success_response_header.status_code}"
52+
assert success_response_header.json()["message"] == "world"
53+
print(" [PASS] Request with token in header was successful.")
54+
55+
# 5. Test empty-body POST with CSRF token in header (should pass validation, then redirect)
56+
print("\nStep 5: Testing empty-body POST to /logout with CSRF token in header (expecting 302)...")
57+
# Note: The /logout endpoint redirects after success, so we expect a 302
58+
# We disable auto-redirects to verify the 302 status directly
59+
empty_body_response = await client.post("/logout", headers=headers, follow_redirects=False)
60+
61+
# If we got a 403, the CSRF check failed. If we got a 302, it passed!
62+
assert empty_body_response.status_code == 302, f"Expected status 302 (Redirect), but got {empty_body_response.status_code}. (403 means CSRF failed)"
63+
print(" [PASS] Empty-body request passed CSRF check and redirected.")
64+
65+
except Exception as e:
66+
print(f"\n--- TEST FAILED ---")
67+
print(f"An error occurred: {e}")
68+
import traceback
69+
traceback.print_exc()
70+
return False
71+
72+
print("\n--- ALL CSRF TESTS PASSED ---")
73+
return True
74+
75+
76+
def main():
77+
print("Starting test server...")
78+
server_process = subprocess.Popen(
79+
[sys.executable, "-m", "uvicorn", "app:app"],
80+
cwd=TEST_APP_DIR,
81+
stdout=subprocess.PIPE,
82+
stderr=subprocess.PIPE,
83+
text=True, # Decode stdout/stderr as text
84+
)
85+
86+
# Give the server more time to start up
87+
print("Waiting 5 seconds for server to start...")
88+
time.sleep(5)
89+
90+
# Check if the server process has terminated unexpectedly
91+
if server_process.poll() is not None:
92+
print("\n--- SERVER FAILED TO START ---")
93+
stdout, stderr = server_process.communicate()
94+
print("STDOUT:")
95+
print(stdout)
96+
print("\nSTDERR:")
97+
print(stderr)
98+
sys.exit(1)
99+
100+
print("Server seems to be running. Starting tests.")
101+
test_passed = False
102+
try:
103+
test_passed = asyncio.run(run_csrf_test())
104+
finally:
105+
print("\nStopping test server...")
106+
server_process.terminate()
107+
# Get remaining output
108+
try:
109+
stdout, stderr = server_process.communicate(timeout=5)
110+
print("\n--- Server Output ---")
111+
print("STDOUT:")
112+
print(stdout)
113+
print("\nSTDERR:")
114+
print(stderr)
115+
except subprocess.TimeoutExpired:
116+
print("Server did not terminate gracefully.")
117+
118+
if not test_passed:
119+
print("\nExiting with status 1 due to test failure.")
120+
sys.exit(1)
121+
122+
123+
if __name__ == "__main__":
124+
main()

jsweb/middleware.py

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from .static import serve_static
44
from .response import Forbidden
5+
import json
56

67
logger = logging.getLogger(__name__)
78

@@ -30,8 +31,10 @@ class CSRFMiddleware(Middleware):
3031
"""
3132
Middleware to protect against Cross-Site Request Forgery (CSRF) attacks.
3233
33-
This middleware checks for a valid CSRF token in POST, PUT, PATCH, and DELETE
34-
requests. It compares a token from the form data against a token stored in a cookie.
34+
This middleware enforces CSRF protection for all state-changing HTTP methods
35+
(POST, PUT, PATCH, DELETE). It requires a valid CSRF token to be present
36+
in the request, either in the 'X-CSRF-Token' header or in the request body
37+
(JSON or Form Data).
3538
"""
3639
async def __call__(self, scope, receive, send):
3740
"""
@@ -52,12 +55,42 @@ async def __call__(self, scope, receive, send):
5255
req = scope['jsweb.request']
5356

5457
if req.method in ("POST", "PUT", "PATCH", "DELETE"):
55-
form = await req.form()
56-
form_token = form.get("csrf_token")
5758
cookie_token = req.cookies.get("csrf_token")
58-
59-
if not form_token or not cookie_token or not secrets.compare_digest(form_token, cookie_token):
60-
logger.error("CSRF VALIDATION FAILED. Tokens do not match or are missing.")
59+
submitted_token = None
60+
61+
# 1. Check header first (Best practice for AJAX/APIs)
62+
submitted_token = req.headers.get("x-csrf-token")
63+
64+
# 2. If no header token, check the body based on content type
65+
if not submitted_token:
66+
content_type = req.headers.get("content-type", "")
67+
68+
if "application/json" in content_type:
69+
try:
70+
# Request.json() safely returns {} for empty/invalid bodies
71+
data = await req.json()
72+
submitted_token = data.get("csrf_token")
73+
except Exception:
74+
# If JSON parsing fails, we treat it as no token found
75+
pass
76+
77+
elif "application/x-www-form-urlencoded" in content_type or "multipart/form-data" in content_type:
78+
try:
79+
# Request.form() safely returns {} for empty/non-form bodies
80+
form = await req.form()
81+
submitted_token = form.get("csrf_token")
82+
except Exception:
83+
# If form parsing fails, we treat it as no token found
84+
pass
85+
86+
# 3. Perform the validation
87+
# Both the cookie token and the submitted token MUST be present and match.
88+
if not cookie_token or not submitted_token or not secrets.compare_digest(submitted_token, cookie_token):
89+
logger.warning(
90+
f"CSRF validation failed for {req.method} {req.path}. "
91+
f"Cookie set: {'Yes' if cookie_token else 'No'}, "
92+
f"Token submitted: {'Yes' if submitted_token else 'No'}."
93+
)
6194
response = Forbidden("CSRF token missing or invalid.")
6295
await response(scope, receive, send)
6396
return

jsweb/request.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(self, scope, receive, app):
3434
self.path = self.scope.get("path", "/")
3535
self.query = self._parse_query(self.scope.get("query_string", b"").decode())
3636
self.headers = self._parse_headers(self.scope.get("headers", []))
37+
self.content_type = self.headers.get("content-type", "")
3738
self.cookies = self._parse_cookies(self.headers)
3839
self.user = None
3940

@@ -101,8 +102,7 @@ async def json(self):
101102
dict: The parsed JSON data.
102103
"""
103104
if self._json is None:
104-
content_type = self.headers.get("content-type", "")
105-
if "application/json" in content_type:
105+
if "application/json" in self.content_type:
106106
try:
107107
body_bytes = await self.body()
108108
self._json = json.loads(body_bytes) if body_bytes else {}

0 commit comments

Comments
 (0)