Skip to content

Commit ec0d87d

Browse files
authored
allow CSPMiddleware default-src to be set (#278)
1 parent 1047227 commit ec0d87d

File tree

2 files changed

+30
-7
lines changed

2 files changed

+30
-7
lines changed

piccolo_api/csp/middleware.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
@dataclass
1212
class CSPConfig:
1313
report_uri: t.Optional[bytes] = None
14+
default_src: str = "self"
1415

1516

1617
class CSPMiddleware:
@@ -27,7 +28,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
2728
async def wrapped_send(message: Message):
2829
if message["type"] == "http.response.start":
2930
headers = message.get("headers", [])
30-
header_value = b"default-src 'self'"
31+
header_value = bytes(
32+
f"default-src: '{self.config.default_src}'", "utf8"
33+
)
3134
if self.config.report_uri:
3235
header_value = (
3336
header_value

tests/csp/test_csp.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,36 @@ async def app(scope, receive, send):
2525

2626
class TestCSPMiddleware(TestCase):
2727
def test_headers(self):
28+
"""
29+
Make sure the headers are added.
30+
"""
2831
wrapped_app = CSPMiddleware(app)
2932

3033
client = TestClient(wrapped_app)
3134
response = client.request("GET", "/")
3235

33-
header_names = response.headers.keys()
34-
3536
# Make sure the headers got added:
36-
self.assertIn("content-security-policy", header_names)
37+
self.assertEqual(
38+
response.headers["content-security-policy"],
39+
"default-src: 'self'",
40+
)
3741

3842
# Make sure the original headers are still intact:
39-
self.assertIn("content-type", header_names)
43+
self.assertEqual(response.headers["content-type"], "text/plain")
44+
45+
def test_default_src(self):
46+
"""
47+
Make sure the `default-src` value can be set.
48+
"""
49+
wrapped_app = CSPMiddleware(app, config=CSPConfig(default_src="none"))
50+
51+
client = TestClient(wrapped_app)
52+
response = client.request("GET", "/")
53+
54+
self.assertEqual(
55+
response.headers.get("content-security-policy"),
56+
"default-src: 'none'",
57+
)
4058

4159
def test_report_uri(self):
4260
wrapped_app = CSPMiddleware(
@@ -46,5 +64,7 @@ def test_report_uri(self):
4664
client = TestClient(wrapped_app)
4765
response = client.request("GET", "/")
4866

49-
header = response.headers["content-security-policy"]
50-
self.assertIn("report-uri", header)
67+
self.assertEqual(
68+
response.headers["content-security-policy"],
69+
"default-src: 'self'; report-uri foo.com",
70+
)

0 commit comments

Comments
 (0)