Skip to content

Commit 3493b9c

Browse files
authored
fix: add cache headers for images (#9560)
1 parent c9ebe70 commit 3493b9c

File tree

4 files changed

+311
-8
lines changed

4 files changed

+311
-8
lines changed

middleware/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Server middleware modules"""

middleware/cache_middleware.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Cache control middleware for ComfyUI server"""
2+
3+
from aiohttp import web
4+
from typing import Callable, Awaitable
5+
6+
# Time in seconds
7+
ONE_HOUR: int = 3600
8+
ONE_DAY: int = 86400
9+
IMG_EXTENSIONS = (
10+
".jpg",
11+
".jpeg",
12+
".png",
13+
".ppm",
14+
".bmp",
15+
".pgm",
16+
".tif",
17+
".tiff",
18+
".webp",
19+
)
20+
21+
22+
@web.middleware
23+
async def cache_control(
24+
request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]]
25+
) -> web.Response:
26+
"""Cache control middleware that sets appropriate cache headers based on file type and response status"""
27+
response: web.Response = await handler(request)
28+
29+
if (
30+
request.path.endswith(".js")
31+
or request.path.endswith(".css")
32+
or request.path.endswith("index.json")
33+
):
34+
response.headers.setdefault("Cache-Control", "no-cache")
35+
return response
36+
37+
# Early return for non-image files - no cache headers needed
38+
if not request.path.lower().endswith(IMG_EXTENSIONS):
39+
return response
40+
41+
# Handle image files
42+
if response.status == 404:
43+
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}")
44+
elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308):
45+
# Success responses and permanent redirects - cache for 1 day
46+
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}")
47+
elif response.status in (302, 303, 307):
48+
# Temporary redirects - no cache
49+
response.headers.setdefault("Cache-Control", "no-cache")
50+
# Note: 304 Not Modified falls through - no cache headers set
51+
52+
return response

server.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,15 @@
3939
from api_server.routes.internal.internal_routes import InternalRoutes
4040
from protocol import BinaryEventTypes
4141

42+
# Import cache control middleware
43+
from middleware.cache_middleware import cache_control
44+
4245
async def send_socket_catch_exception(function, message):
4346
try:
4447
await function(message)
4548
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
4649
logging.warning("send error: {}".format(err))
4750

48-
@web.middleware
49-
async def cache_control(request: web.Request, handler):
50-
response: web.Response = await handler(request)
51-
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
52-
response.headers.setdefault('Cache-Control', 'no-cache')
53-
return response
54-
55-
5651
@web.middleware
5752
async def compress_body(request: web.Request, handler):
5853
accept_encoding = request.headers.get("Accept-Encoding", "")
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
"""Tests for server cache control middleware"""
2+
3+
import pytest
4+
from aiohttp import web
5+
from aiohttp.test_utils import make_mocked_request
6+
from typing import Dict, Any
7+
8+
from middleware.cache_middleware import cache_control, ONE_HOUR, ONE_DAY, IMG_EXTENSIONS
9+
10+
pytestmark = pytest.mark.asyncio # Apply asyncio mark to all tests
11+
12+
# Test configuration data
13+
CACHE_SCENARIOS = [
14+
# Image file scenarios
15+
{
16+
"name": "image_200_status",
17+
"path": "/test.jpg",
18+
"status": 200,
19+
"expected_cache": f"public, max-age={ONE_DAY}",
20+
"should_have_header": True,
21+
},
22+
{
23+
"name": "image_404_status",
24+
"path": "/missing.jpg",
25+
"status": 404,
26+
"expected_cache": f"public, max-age={ONE_HOUR}",
27+
"should_have_header": True,
28+
},
29+
# JavaScript/CSS scenarios
30+
{
31+
"name": "js_no_cache",
32+
"path": "/script.js",
33+
"status": 200,
34+
"expected_cache": "no-cache",
35+
"should_have_header": True,
36+
},
37+
{
38+
"name": "css_no_cache",
39+
"path": "/styles.css",
40+
"status": 200,
41+
"expected_cache": "no-cache",
42+
"should_have_header": True,
43+
},
44+
{
45+
"name": "index_json_no_cache",
46+
"path": "/api/index.json",
47+
"status": 200,
48+
"expected_cache": "no-cache",
49+
"should_have_header": True,
50+
},
51+
# Non-matching files
52+
{
53+
"name": "html_no_header",
54+
"path": "/index.html",
55+
"status": 200,
56+
"expected_cache": None,
57+
"should_have_header": False,
58+
},
59+
{
60+
"name": "txt_no_header",
61+
"path": "/data.txt",
62+
"status": 200,
63+
"expected_cache": None,
64+
"should_have_header": False,
65+
},
66+
{
67+
"name": "api_endpoint_no_header",
68+
"path": "/api/endpoint",
69+
"status": 200,
70+
"expected_cache": None,
71+
"should_have_header": False,
72+
},
73+
{
74+
"name": "pdf_no_header",
75+
"path": "/file.pdf",
76+
"status": 200,
77+
"expected_cache": None,
78+
"should_have_header": False,
79+
},
80+
]
81+
82+
# Status code scenarios for images
83+
IMAGE_STATUS_SCENARIOS = [
84+
# Success statuses get long cache
85+
{"status": 200, "expected": f"public, max-age={ONE_DAY}"},
86+
{"status": 201, "expected": f"public, max-age={ONE_DAY}"},
87+
{"status": 202, "expected": f"public, max-age={ONE_DAY}"},
88+
{"status": 204, "expected": f"public, max-age={ONE_DAY}"},
89+
{"status": 206, "expected": f"public, max-age={ONE_DAY}"},
90+
# Permanent redirects get long cache
91+
{"status": 301, "expected": f"public, max-age={ONE_DAY}"},
92+
{"status": 308, "expected": f"public, max-age={ONE_DAY}"},
93+
# Temporary redirects get no cache
94+
{"status": 302, "expected": "no-cache"},
95+
{"status": 303, "expected": "no-cache"},
96+
{"status": 307, "expected": "no-cache"},
97+
# 404 gets short cache
98+
{"status": 404, "expected": f"public, max-age={ONE_HOUR}"},
99+
]
100+
101+
# Case sensitivity test paths
102+
CASE_SENSITIVITY_PATHS = ["/image.JPG", "/photo.PNG", "/pic.JpEg"]
103+
104+
# Edge case test paths
105+
EDGE_CASE_PATHS = [
106+
{
107+
"name": "query_strings_ignored",
108+
"path": "/image.jpg?v=123&size=large",
109+
"expected": f"public, max-age={ONE_DAY}",
110+
},
111+
{
112+
"name": "multiple_dots_in_path",
113+
"path": "/image.min.jpg",
114+
"expected": f"public, max-age={ONE_DAY}",
115+
},
116+
{
117+
"name": "nested_paths_with_images",
118+
"path": "/static/images/photo.jpg",
119+
"expected": f"public, max-age={ONE_DAY}",
120+
},
121+
]
122+
123+
124+
class TestCacheControl:
125+
"""Test cache control middleware functionality"""
126+
127+
@pytest.fixture
128+
def status_handler_factory(self):
129+
"""Create a factory for handlers that return specific status codes"""
130+
131+
def factory(status: int, headers: Dict[str, str] = None):
132+
async def handler(request):
133+
return web.Response(status=status, headers=headers or {})
134+
135+
return handler
136+
137+
return factory
138+
139+
@pytest.fixture
140+
def mock_handler(self, status_handler_factory):
141+
"""Create a mock handler that returns a response with 200 status"""
142+
return status_handler_factory(200)
143+
144+
@pytest.fixture
145+
def handler_with_existing_cache(self, status_handler_factory):
146+
"""Create a handler that returns response with existing Cache-Control header"""
147+
return status_handler_factory(200, {"Cache-Control": "max-age=3600"})
148+
149+
async def assert_cache_header(
150+
self,
151+
response: web.Response,
152+
expected_cache: str = None,
153+
should_have_header: bool = True,
154+
):
155+
"""Helper to assert cache control headers"""
156+
if should_have_header:
157+
assert "Cache-Control" in response.headers
158+
if expected_cache:
159+
assert response.headers["Cache-Control"] == expected_cache
160+
else:
161+
assert "Cache-Control" not in response.headers
162+
163+
# Parameterized tests
164+
@pytest.mark.parametrize("scenario", CACHE_SCENARIOS, ids=lambda x: x["name"])
165+
async def test_cache_control_scenarios(
166+
self, scenario: Dict[str, Any], status_handler_factory
167+
):
168+
"""Test various cache control scenarios"""
169+
handler = status_handler_factory(scenario["status"])
170+
request = make_mocked_request("GET", scenario["path"])
171+
response = await cache_control(request, handler)
172+
173+
assert response.status == scenario["status"]
174+
await self.assert_cache_header(
175+
response, scenario["expected_cache"], scenario["should_have_header"]
176+
)
177+
178+
@pytest.mark.parametrize("ext", IMG_EXTENSIONS)
179+
async def test_all_image_extensions(self, ext: str, mock_handler):
180+
"""Test all defined image extensions are handled correctly"""
181+
request = make_mocked_request("GET", f"/image{ext}")
182+
response = await cache_control(request, mock_handler)
183+
184+
assert response.status == 200
185+
assert "Cache-Control" in response.headers
186+
assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}"
187+
188+
@pytest.mark.parametrize(
189+
"status_scenario", IMAGE_STATUS_SCENARIOS, ids=lambda x: f"status_{x['status']}"
190+
)
191+
async def test_image_status_codes(
192+
self, status_scenario: Dict[str, Any], status_handler_factory
193+
):
194+
"""Test different status codes for image requests"""
195+
handler = status_handler_factory(status_scenario["status"])
196+
request = make_mocked_request("GET", "/image.jpg")
197+
response = await cache_control(request, handler)
198+
199+
assert response.status == status_scenario["status"]
200+
assert "Cache-Control" in response.headers
201+
assert response.headers["Cache-Control"] == status_scenario["expected"]
202+
203+
@pytest.mark.parametrize("path", CASE_SENSITIVITY_PATHS)
204+
async def test_case_insensitive_image_extension(self, path: str, mock_handler):
205+
"""Test that image extensions are matched case-insensitively"""
206+
request = make_mocked_request("GET", path)
207+
response = await cache_control(request, mock_handler)
208+
209+
assert "Cache-Control" in response.headers
210+
assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}"
211+
212+
@pytest.mark.parametrize("edge_case", EDGE_CASE_PATHS, ids=lambda x: x["name"])
213+
async def test_edge_cases(self, edge_case: Dict[str, str], mock_handler):
214+
"""Test edge cases like query strings, nested paths, etc."""
215+
request = make_mocked_request("GET", edge_case["path"])
216+
response = await cache_control(request, mock_handler)
217+
218+
assert "Cache-Control" in response.headers
219+
assert response.headers["Cache-Control"] == edge_case["expected"]
220+
221+
# Header preservation tests (special cases not covered by parameterization)
222+
async def test_js_preserves_existing_headers(self, handler_with_existing_cache):
223+
"""Test that .js files preserve existing Cache-Control headers"""
224+
request = make_mocked_request("GET", "/script.js")
225+
response = await cache_control(request, handler_with_existing_cache)
226+
227+
# setdefault should preserve existing header
228+
assert response.headers["Cache-Control"] == "max-age=3600"
229+
230+
async def test_css_preserves_existing_headers(self, handler_with_existing_cache):
231+
"""Test that .css files preserve existing Cache-Control headers"""
232+
request = make_mocked_request("GET", "/styles.css")
233+
response = await cache_control(request, handler_with_existing_cache)
234+
235+
# setdefault should preserve existing header
236+
assert response.headers["Cache-Control"] == "max-age=3600"
237+
238+
async def test_image_preserves_existing_headers(self, status_handler_factory):
239+
"""Test that image cache headers preserve existing Cache-Control"""
240+
handler = status_handler_factory(200, {"Cache-Control": "private, no-cache"})
241+
request = make_mocked_request("GET", "/image.jpg")
242+
response = await cache_control(request, handler)
243+
244+
# setdefault should preserve existing header
245+
assert response.headers["Cache-Control"] == "private, no-cache"
246+
247+
async def test_304_not_modified_inherits_cache(self, status_handler_factory):
248+
"""Test that 304 Not Modified doesn't set cache headers for images"""
249+
handler = status_handler_factory(304, {"Cache-Control": "max-age=7200"})
250+
request = make_mocked_request("GET", "/not-modified.jpg")
251+
response = await cache_control(request, handler)
252+
253+
assert response.status == 304
254+
# Should preserve existing cache header, not override
255+
assert response.headers["Cache-Control"] == "max-age=7200"

0 commit comments

Comments
 (0)