Skip to content

Commit ba2e5d9

Browse files
committed
Bug Fix : Issue no #18 CSRF Protection Bypass for JSON APIs fixed
1 parent a9b1db4 commit ba2e5d9

File tree

2 files changed

+127
-4
lines changed

2 files changed

+127
-4
lines changed

Tests/test_csrf_json.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import asyncio
2+
import httpx
3+
import subprocess
4+
import sys
5+
import time
6+
import os
7+
8+
# Construct absolute path to the test application directory
9+
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
10+
TEST_APP_DIR = os.path.join(TESTS_DIR, "test")
11+
12+
# Ensure the test application is in the python path
13+
sys.path.insert(0, TEST_APP_DIR)
14+
15+
BASE_URL = "http://127.0.0.1:8000"
16+
17+
async def run_csrf_test():
18+
"""
19+
Tests that CSRF protection works correctly for JSON endpoints.
20+
"""
21+
print("--- Starting CSRF JSON Test ---")
22+
async with httpx.AsyncClient(base_url=BASE_URL) as client:
23+
try:
24+
# 1. Make a GET request to a page to get a CSRF token
25+
print("Step 1: Getting CSRF token from homepage...")
26+
get_response = await client.get("/")
27+
get_response.raise_for_status()
28+
assert "csrf_token" in client.cookies, "CSRF token not found in cookie"
29+
csrf_token = client.cookies["csrf_token"]
30+
print(" [PASS] CSRF token received.")
31+
32+
# 2. Send a POST request to the JSON endpoint WITHOUT a CSRF token
33+
print("\nStep 2: Testing POST to /api/test without CSRF token (expecting 403)...")
34+
payload = {"message": "hello"}
35+
fail_response = await client.post("/api/test", json=payload)
36+
assert fail_response.status_code == 403, f"Expected status 403, but got {fail_response.status_code}"
37+
assert "CSRF token missing or invalid" in fail_response.text
38+
print(" [PASS] Request was correctly forbidden.")
39+
40+
# 3. Send a POST request to the JSON endpoint WITH the correct CSRF token
41+
print("\nStep 3: Testing POST to /api/test with CSRF token (expecting 200)...")
42+
payload_with_token = {"message": "hello", "csrf_token": csrf_token}
43+
success_response = await client.post("/api/test", json=payload_with_token)
44+
assert success_response.status_code == 200, f"Expected status 200, but got {success_response.status_code}"
45+
response_json = success_response.json()
46+
assert response_json["message"] == "hello"
47+
print(" [PASS] Request was successful.")
48+
49+
except Exception as e:
50+
print(f"\n--- TEST FAILED ---")
51+
print(f"An error occurred: {e}")
52+
import traceback
53+
traceback.print_exc()
54+
return False
55+
56+
print("\n--- TEST PASSED ---")
57+
return True
58+
59+
60+
def main():
61+
print("Starting test server...")
62+
server_process = subprocess.Popen(
63+
[sys.executable, "-m", "uvicorn", "app:app"],
64+
cwd=TEST_APP_DIR,
65+
stdout=subprocess.PIPE,
66+
stderr=subprocess.PIPE,
67+
text=True, # Decode stdout/stderr as text
68+
)
69+
70+
# Give the server more time to start up
71+
print("Waiting 5 seconds for server to start...")
72+
time.sleep(5)
73+
74+
# Check if the server process has terminated unexpectedly
75+
if server_process.poll() is not None:
76+
print("\n--- SERVER FAILED TO START ---")
77+
stdout, stderr = server_process.communicate()
78+
print("STDOUT:")
79+
print(stdout)
80+
print("\nSTDERR:")
81+
print(stderr)
82+
sys.exit(1)
83+
84+
print("Server seems to be running. Starting tests.")
85+
test_passed = False
86+
try:
87+
test_passed = asyncio.run(run_csrf_test())
88+
finally:
89+
print("\nStopping test server...")
90+
server_process.terminate()
91+
# Get remaining output
92+
stdout, stderr = server_process.communicate(timeout=5)
93+
94+
print("\n--- Server Output ---")
95+
print("STDOUT:")
96+
print(stdout)
97+
print("\nSTDERR:")
98+
print(stderr)
99+
100+
if not test_passed:
101+
print("\nExiting with status 1 due to test failure.")
102+
sys.exit(1)
103+
104+
105+
if __name__ == "__main__":
106+
main()

jsweb/middleware.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ class CSRFMiddleware(Middleware):
3131
Middleware to protect against Cross-Site Request Forgery (CSRF) attacks.
3232
3333
This middleware checks for a valid CSRF token in POST, PUT, PATCH, and DELETE
34-
requests. It compares a token from the form data against a token stored in a cookie.
34+
requests. It supports both form-based and JSON-based requests.
3535
"""
36+
3637
async def __call__(self, scope, receive, send):
3738
"""
3839
Validates the CSRF token for state-changing HTTP methods.
@@ -52,18 +53,34 @@ async def __call__(self, scope, receive, send):
5253
req = scope['jsweb.request']
5354

5455
if req.method in ("POST", "PUT", "PATCH", "DELETE"):
55-
form = await req.form()
56-
form_token = form.get("csrf_token")
56+
token = await self._get_token_from_request(req)
5757
cookie_token = req.cookies.get("csrf_token")
5858

59-
if not form_token or not cookie_token or not secrets.compare_digest(form_token, cookie_token):
59+
if not token or not cookie_token or not secrets.compare_digest(token, cookie_token):
6060
logger.error("CSRF VALIDATION FAILED. Tokens do not match or are missing.")
6161
response = Forbidden("CSRF token missing or invalid.")
6262
await response(scope, receive, send)
6363
return
6464

6565
await self.app(scope, receive, send)
6666

67+
async def _get_token_from_request(self, req):
68+
"""
69+
Extracts the CSRF token from the request, handling both JSON and form data.
70+
"""
71+
content_type = req.headers.get("content-type", "")
72+
if "application/json" in content_type:
73+
try:
74+
data = await req.json()
75+
return data.get("csrf_token")
76+
except Exception:
77+
# In case of malformed JSON, treat as if no token was sent
78+
return None
79+
else:
80+
# Fallback for form data
81+
form = await req.form()
82+
return form.get("csrf_token")
83+
6784
class StaticFilesMiddleware(Middleware):
6885
"""
6986
Middleware for serving static files.

0 commit comments

Comments
 (0)