Skip to content

Commit cf84c9f

Browse files
add initial registry for staticfiles (#46)
* add registry for staticfiles * change to using dataclasses instead of ABC * add tests and adjust asset type enum * remove asset tags moving to another branch
1 parent e04cf55 commit cf84c9f

File tree

2 files changed

+289
-0
lines changed

2 files changed

+289
-0
lines changed

src/django_bird/staticfiles.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
4+
from dataclasses import dataclass
5+
from enum import IntEnum
6+
from pathlib import Path
7+
8+
from django.templatetags.static import static
9+
from django.utils.html import format_html
10+
from django.utils.html import format_html_join
11+
from django.utils.safestring import SafeString
12+
13+
from ._typing import override
14+
15+
16+
class AssetType(IntEnum):
17+
CSS = 1
18+
JS = 2
19+
20+
21+
@dataclass(frozen=True)
22+
class Asset:
23+
path: Path
24+
type: AssetType
25+
26+
@override
27+
def __hash__(self) -> int:
28+
return hash((str(self.path), self.type))
29+
30+
def exists(self) -> bool:
31+
return self.path.exists()
32+
33+
@classmethod
34+
def from_path(cls, path: Path) -> Asset:
35+
match path.suffix.lower():
36+
case ".css":
37+
asset_type = AssetType.CSS
38+
case ".js":
39+
asset_type = AssetType.JS
40+
case _:
41+
raise ValueError(f"Unknown asset type for path: {path}")
42+
return cls(path=path, type=asset_type)
43+
44+
45+
class Registry:
46+
def __init__(self) -> None:
47+
self.assets: set[Asset] = set()
48+
49+
def register(self, asset: Asset | Path) -> None:
50+
if isinstance(asset, Path):
51+
asset = Asset.from_path(asset)
52+
53+
if not asset.exists():
54+
raise FileNotFoundError(f"Asset file not found: {asset.path}")
55+
56+
self.assets.add(asset)
57+
58+
def clear(self) -> None:
59+
self.assets.clear()
60+
61+
def get_assets(self, asset_type: AssetType) -> list[Asset]:
62+
assets = [asset for asset in self.assets if asset.type == asset_type]
63+
return self.sort_assets(assets)
64+
65+
def sort_assets(
66+
self,
67+
assets: list[Asset],
68+
*,
69+
key: Callable[[Asset], str] = lambda a: str(a.path),
70+
) -> list[Asset]:
71+
return sorted(assets, key=key)
72+
73+
def get_format_string(self, asset_type: AssetType) -> str:
74+
match asset_type:
75+
case AssetType.CSS:
76+
return '<link rel="stylesheet" href="{}" type="text/css">'
77+
case AssetType.JS:
78+
return '<script src="{}" type="text/javascript">'
79+
80+
def render(self, asset_type: AssetType) -> SafeString:
81+
assets = self.get_assets(asset_type)
82+
83+
if not assets:
84+
return format_html("")
85+
86+
return format_html_join(
87+
"\n",
88+
self.get_format_string(asset_type),
89+
((static(str(asset.path)),) for asset in assets),
90+
)
91+
92+
93+
registry = Registry()

tests/test_staticfiles.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
5+
import pytest
6+
from django.utils.safestring import SafeString
7+
8+
from django_bird.staticfiles import Asset
9+
from django_bird.staticfiles import AssetType
10+
from django_bird.staticfiles import registry
11+
12+
13+
class TestAsset:
14+
def test_hash(self):
15+
asset1 = Asset(Path("static.css"), AssetType.CSS)
16+
asset2 = Asset(Path("static.css"), AssetType.CSS)
17+
18+
assert asset1 == asset2
19+
assert hash(asset1) == hash(asset2)
20+
21+
assets = {asset1, asset2, Asset(Path("other.css"), AssetType.CSS)}
22+
23+
assert len(assets) == 2
24+
25+
def test_exists(self, tmp_path: Path):
26+
css_file = tmp_path / "test.css"
27+
css_file.touch()
28+
29+
asset = Asset(css_file, AssetType.CSS)
30+
31+
assert asset.exists() is True
32+
33+
def test_exists_nonexistent(self):
34+
missing_asset = Asset(Path("missing.css"), AssetType.CSS)
35+
assert missing_asset.exists() is False
36+
37+
@pytest.mark.parametrize(
38+
"path,expected",
39+
[
40+
(Path("static.css"), Asset(Path("static.css"), AssetType.CSS)),
41+
(Path("static.js"), Asset(Path("static.js"), AssetType.JS)),
42+
(
43+
Path("nested/path/style.css"),
44+
Asset(Path("nested/path/style.css"), AssetType.CSS),
45+
),
46+
(
47+
Path("./relative/script.js"),
48+
Asset(Path("./relative/script.js"), AssetType.JS),
49+
),
50+
(Path("UPPERCASE.CSS"), Asset(Path("UPPERCASE.CSS"), AssetType.CSS)),
51+
(Path("mixed.Js"), Asset(Path("mixed.Js"), AssetType.JS)),
52+
],
53+
)
54+
def test_from_path(self, path, expected):
55+
assert Asset.from_path(path) == expected
56+
57+
def test_from_path_invalid(self):
58+
with pytest.raises(ValueError):
59+
Asset.from_path(Path("static.html"))
60+
61+
62+
class TestRegistry:
63+
@pytest.fixture
64+
def registry(self):
65+
registry.clear()
66+
yield registry
67+
registry.clear()
68+
69+
def test_register_asset(self, registry, tmp_path):
70+
css_file = tmp_path / "test.css"
71+
css_file.touch()
72+
asset = Asset(css_file, AssetType.CSS)
73+
74+
registry.register(asset)
75+
76+
assert asset in registry.assets
77+
78+
def test_register_path(self, registry, tmp_path):
79+
css_file = tmp_path / "test.css"
80+
css_file.touch()
81+
82+
registry.register(css_file)
83+
84+
assert len(registry.assets) == 1
85+
assert next(iter(registry.assets)).path == css_file
86+
87+
def test_register_missing_file(self, registry, tmp_path):
88+
missing_file = tmp_path / "missing.css"
89+
90+
with pytest.raises(FileNotFoundError):
91+
registry.register(missing_file)
92+
93+
def test_clear(self, registry, tmp_path):
94+
css_file = tmp_path / "test.css"
95+
css_file.touch()
96+
js_file = tmp_path / "test.js"
97+
js_file.touch()
98+
registry.register(Asset(css_file, AssetType.CSS))
99+
registry.register(Asset(js_file, AssetType.JS))
100+
101+
assert len(registry.assets) == 2
102+
103+
registry.clear()
104+
105+
assert len(registry.assets) == 0
106+
107+
@pytest.mark.parametrize("asset_type", [AssetType.CSS, AssetType.JS])
108+
def test_get_assets(self, asset_type, registry):
109+
css_asset = Asset(Path("test.css"), AssetType.CSS)
110+
js_asset = Asset(Path("test.js"), AssetType.JS)
111+
112+
registry.assets = {css_asset, js_asset}
113+
114+
assets = registry.get_assets(asset_type)
115+
116+
assert len(assets) == 1
117+
assert all(asset.type == asset_type for asset in assets)
118+
119+
@pytest.mark.parametrize(
120+
"assets,sort_key,expected",
121+
[
122+
(
123+
[
124+
Asset(Path("test/a.css"), AssetType.CSS),
125+
Asset(Path("test/b.css"), AssetType.CSS),
126+
],
127+
None,
128+
["test/a.css", "test/b.css"],
129+
),
130+
(
131+
[
132+
Asset(Path("test/b.css"), AssetType.CSS),
133+
Asset(Path("other/a.css"), AssetType.CSS),
134+
],
135+
lambda a: a.path.name,
136+
["other/a.css", "test/b.css"],
137+
),
138+
([], None, []),
139+
(
140+
[Asset(Path("test.css"), AssetType.CSS)],
141+
None,
142+
["test.css"],
143+
),
144+
],
145+
)
146+
def test_sort_assets(self, assets, sort_key, expected, registry):
147+
kwargs = {}
148+
if sort_key is not None:
149+
kwargs["key"] = sort_key
150+
151+
sorted_assets = registry.sort_assets(assets, **kwargs)
152+
153+
assert [str(a.path) for a in sorted_assets] == expected
154+
155+
@pytest.mark.parametrize(
156+
"asset_type,expected",
157+
[
158+
(AssetType.CSS, '<link rel="stylesheet" href="{}" type="text/css">'),
159+
(AssetType.JS, '<script src="{}" type="text/javascript">'),
160+
],
161+
)
162+
def test_get_format_string(self, asset_type, expected, registry):
163+
assert registry.get_format_string(asset_type) == expected
164+
165+
@pytest.mark.parametrize(
166+
"assets,asset_type,expected",
167+
[
168+
(set(), AssetType.CSS, ""),
169+
(
170+
{Asset(Path("test.css"), AssetType.CSS)},
171+
AssetType.CSS,
172+
('rel="stylesheet"', 'href="test.css"'),
173+
),
174+
(
175+
{Asset(Path("test.js"), AssetType.JS)},
176+
AssetType.JS,
177+
('<script src="test.js"',),
178+
),
179+
(
180+
{
181+
Asset(Path("a.css"), AssetType.CSS),
182+
Asset(Path("b.css"), AssetType.CSS),
183+
},
184+
AssetType.CSS,
185+
('stylesheet" href="a.css"', 'stylesheet" href="b.css"'),
186+
),
187+
],
188+
)
189+
def test_render(self, assets, asset_type, expected, registry):
190+
registry.assets = assets
191+
192+
rendered = registry.render(asset_type)
193+
194+
assert isinstance(rendered, SafeString)
195+
for content in expected:
196+
assert content in rendered

0 commit comments

Comments
 (0)