Skip to content

Commit fa118e3

Browse files
committed
Use Pydantic for FileMetadata validation
1 parent a7cfb95 commit fa118e3

File tree

10 files changed

+127
-116
lines changed

10 files changed

+127
-116
lines changed

lib/metadata.py

Lines changed: 30 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,51 @@
1-
import json
2-
import math
3-
from dataclasses import dataclass, asdict
41
from starlette.datastructures import Headers
5-
from typing import Optional, Self
2+
from pydantic import BaseModel, Field, field_validator, ByteSize, StrictStr, ConfigDict, AliasChoices
3+
from typing import Optional, Self, Annotated
64

75

8-
@dataclass(frozen=True)
9-
class FileMetadata:
10-
size: int
11-
name: str
12-
content_type: Optional[str] = None
6+
class FileMetadata(BaseModel):
7+
name: StrictStr = Field(description="File name", min_length=2, max_length=255, validation_alias=AliasChoices('name', 'file_name'))
8+
size: ByteSize = Field(description="Size in bytes", gt=0,validation_alias=AliasChoices('size', 'file_size'))
9+
type: StrictStr = Field(description="MIME type", default='application/octet-stream',validation_alias=AliasChoices('type', 'file_type', 'content_type'))
1310

14-
def to_json(self) -> str:
15-
return json.dumps(asdict(self), skipkeys=True)
11+
model_config = ConfigDict(validate_by_name=True, populate_by_name=True)
1612

17-
def to_readable_dict(self) -> dict:
18-
return dict(
19-
file_name=self.name,
20-
file_size=self.format_size(self.size),
21-
file_type=self.format_type(self.content_type)
22-
)
13+
@field_validator('name')
14+
@classmethod
15+
def validate_name(cls, v: str) -> str:
16+
safe_filename = str(v).translate(str.maketrans(':;|*@/\\', ' ')).strip()
17+
return safe_filename.encode('latin-1', 'ignore').decode('utf-8', 'ignore')
2318

2419
@classmethod
2520
def from_json(cls, data: str) -> Self:
26-
return cls(**json.loads(data))
21+
return cls.model_validate_json(data)
22+
23+
def to_json(self) -> str:
24+
return self.model_dump_json()
2725

2826
@classmethod
2927
def get_from_http_headers(cls, headers: Headers, filename: str) -> Self:
28+
"""Create metadata from headers of an HTTP upload request."""
3029
return cls(
31-
name=cls.escape_filename(filename),
32-
size=cls.process_length(headers.get('content-length', '0')),
33-
content_type=headers.get('content-type', '')
30+
name=filename,
31+
size=headers.get('content-length', '0'),
32+
type=headers.get('content-type', '') or None
3433
)
3534

3635
@classmethod
3736
def get_from_json(cls, header: dict) -> Self:
38-
return cls(
39-
name=cls.escape_filename(header['file_name']),
40-
size=cls.process_length(header['file_size']),
41-
content_type=header['file_type']
42-
)
43-
44-
@staticmethod
45-
def escape_filename(filename: str | int) -> str:
46-
"""Escape special characters in the filename."""
47-
safe_filename = str(filename).translate(str.maketrans('', '', ':;|*@/\\'))
48-
return safe_filename.encode('latin-1', 'ignore').decode('utf-8', 'ignore')
49-
50-
@staticmethod
51-
def format_size(size_bytes: int) -> str:
52-
"""Return human-readable file size."""
53-
if size_bytes == 0:
54-
return "0 B"
55-
units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB")
56-
i = math.floor(math.log(size_bytes, 1024))
57-
p = math.pow(1024, i)
58-
s = round(size_bytes / p, 1)
59-
return f"{s} {units[i]}"
60-
61-
@staticmethod
62-
def format_type(content_type: Optional[str]) -> str:
63-
"""Return human-readable file type."""
64-
return content_type or "unknown"
37+
"""Create metadata from a JSON dictionary."""
38+
return cls(**header)
6539

66-
@staticmethod
67-
def process_length(length: str | int) -> int:
68-
"""Convert size string to bytes."""
69-
try:
70-
size = int(str(length).strip().replace(' ', ''))
71-
except ValueError:
72-
raise ValueError(f"Invalid size format: {length}")
73-
if size <= 0:
74-
raise ValueError("File size has to be positive.")
75-
return size
40+
def to_readable_dict(self) -> dict:
41+
return dict(
42+
file_name=self.name,
43+
file_size=self.size.human_readable(),
44+
file_type=self.type,
45+
)
7646

7747
def __str__(self):
78-
return f"{self.name} ({self.size/(1024**2):.1f} MiB - {self.content_type})"
48+
return f"{self.name} ({self.size.human_readable()} - {self.type})"
7949

8050
def __repr__(self):
81-
return f"FileMetadata(name={self.name!r}, size={self.size/(1024**2):.1f}, content_type={self.content_type!r})"
51+
return f"FileMetadata(name={self.name!r}, size={self.size.human_readable()}, type={self.type!r})"

lib/transfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _format_uid(uid: str):
4444
return str(uid).strip().encode('ascii', 'ignore').decode()
4545

4646
def get_file_info(self):
47-
return self.file.name, self.file.size, self.file.content_type
47+
return self.file.name, self.file.size, self.file.type
4848

4949
async def wait_for_event(self, event_name: str, timeout: float = 300.0):
5050
await self.store.wait_for_event(event_name, timeout)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ jinja2>=3.1.6
99
httpx>=0.28.1
1010
rich>=14.0.0
1111
sentry-sdk[fastapi]>=2.32.0
12+
pydantic>=2.0.0

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def get_redis_override(*args, **kwargs) -> redis.Redis:
2828
app.state.redis = redis_client
2929

3030
transport = httpx.ASGITransport(app=app)
31-
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
31+
async with httpx.AsyncClient(transport=transport, base_url="http://testserver") as client:
3232
# Patch the `get_redis` method of the `Store` class
3333
with patch.object(Store, 'get_redis', new=get_redis_override):
3434
print("")
@@ -45,7 +45,7 @@ def get_redis_override(*args, **kwargs) -> redis.Redis:
4545

4646
# Patch the `get_redis` method of the `Store` class
4747
with patch.object(Store, 'get_redis', new=get_redis_override):
48-
with TestClient(app) as client:
48+
with TestClient(app, base_url="http://testserver") as client:
4949
print("")
5050
yield client
5151

tests/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def generate_test_file(size_in_kb: int = 10) -> tuple[bytes, FileMetadata]:
1414
metadata = FileMetadata(
1515
name="test_file.bin",
1616
size=len(content),
17-
content_type="application/octet-stream"
17+
type="application/octet-stream"
1818
)
1919
return content, metadata
2020

tests/test_endpoints.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ async def test_transfer_id_already_used(websocket_client):
4444
ws.send_json({
4545
'file_name': file_metadata.name,
4646
'file_size': file_metadata.size,
47-
'file_type': file_metadata.content_type
47+
'file_type': file_metadata.type
4848
})
4949

5050
# Second attempt should fail with an error message
5151
with websocket_client.websocket_connect(f"/send/{uid}") as ws2:
5252
ws2.send_json({
5353
'file_name': file_metadata.name,
5454
'file_size': file_metadata.size,
55-
'file_type': file_metadata.content_type
55+
'file_type': file_metadata.type
5656
})
5757
response = ws2.receive_text()
5858
assert "Error: Transfer ID is already used." in response
@@ -76,7 +76,7 @@ async def mock_wait_for_client_connected(self):
7676
ws.send_json({
7777
'file_name': file_metadata.name,
7878
'file_size': file_metadata.size,
79-
'file_type': file_metadata.content_type
79+
'file_type': file_metadata.type
8080
})
8181
# This should timeout because we are not starting a receiver
8282
response = ws.receive_text()
@@ -97,7 +97,7 @@ async def sender():
9797
ws.send_json({
9898
'file_name': file_metadata.name,
9999
'file_size': file_metadata.size,
100-
'file_type': file_metadata.content_type
100+
'file_type': file_metadata.type
101101
})
102102
await asyncio.sleep(1.0) # Allow receiver to connect
103103

@@ -142,7 +142,7 @@ async def test_prefetcher_request(test_client: httpx.AsyncClient, websocket_clie
142142
ws.send_json({
143143
'file_name': file_metadata.name,
144144
'file_size': file_metadata.size,
145-
'file_type': file_metadata.content_type
145+
'file_type': file_metadata.type
146146
})
147147
await asyncio.sleep(1.0)
148148

@@ -168,7 +168,7 @@ async def test_browser_download_page(test_client: httpx.AsyncClient, websocket_c
168168
ws.send_json({
169169
'file_name': file_metadata.name,
170170
'file_size': file_metadata.size,
171-
'file_type': file_metadata.content_type
171+
'file_type': file_metadata.type
172172
})
173173
await asyncio.sleep(1.0)
174174

tests/test_journeys.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ async def sender():
1818
ws.send_json({
1919
'file_name': file_metadata.name,
2020
'file_size': file_metadata.size,
21-
'file_type': file_metadata.content_type
21+
'file_type': file_metadata.type
2222
})
2323
await asyncio.sleep(1.0)
2424

@@ -72,7 +72,7 @@ async def test_http_upload_http_download(test_client: httpx.AsyncClient):
7272

7373
async def sender():
7474
headers = {
75-
'Content-Type': file_metadata.content_type,
75+
'Content-Type': file_metadata.type,
7676
'Content-Length': str(file_metadata.size)
7777
}
7878
async with test_client.stream("PUT", f"/{uid}/{file_metadata.name}", content=file_content, headers=headers) as response:

tests/test_unit.py

Lines changed: 76 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,80 @@
11
import pytest
2+
from pydantic import ValidationError
23
from lib.metadata import FileMetadata
34

45

5-
@pytest.mark.parametrize("size_bytes, expected", [
6-
(0, "0 B"),
7-
(1023, "1023.0 B"),
8-
(1024, "1.0 KiB"),
9-
(1536, "1.5 KiB"),
10-
(1024 ** 2, "1.0 MiB"),
11-
(int(1.5 * 1024 ** 2), "1.5 MiB"),
12-
(1024 ** 3, "1.0 GiB"),
13-
])
14-
def test_format_size(size_bytes, expected):
15-
assert FileMetadata.format_size(size_bytes) == expected
16-
17-
18-
@pytest.mark.parametrize("filename, expected", [
19-
("file:name.txt", "filename.txt"),
20-
("file|name.txt", "filename.txt"),
21-
("[email protected]", "filename.txt"),
22-
("file/name.txt", "filename.txt"),
23-
("file\\name.txt", "filename.txt"),
24-
("valid-name.zip", "valid-name.zip"),
25-
])
26-
def test_escape_filename(filename, expected):
27-
assert FileMetadata.escape_filename(filename) == expected
28-
29-
30-
@pytest.mark.parametrize("length, expected", [
31-
("1024", 1024),
32-
(2048, 2048),
33-
(" 4096 ", 4096),
34-
])
35-
def test_process_length(length, expected):
36-
assert FileMetadata.process_length(length) == expected
37-
38-
39-
@pytest.mark.parametrize("invalid_length", ["-100", "0", "abc", "1.5"])
40-
def test_process_length_invalid(invalid_length):
41-
with pytest.raises(ValueError):
42-
FileMetadata.process_length(invalid_length)
6+
def test_file_metadata_creation():
7+
"""Test that FileMetadata can be created with valid data."""
8+
metadata = FileMetadata(
9+
name="test.txt",
10+
size=1024,
11+
content_type="text/plain"
12+
)
13+
assert metadata.name == "test.txt"
14+
assert metadata.size == 1024
15+
assert metadata.type == "text/plain"
16+
17+
18+
def test_file_metadata_validation_invalid_size():
19+
"""Test that FileMetadata validates size field."""
20+
with pytest.raises(ValidationError):
21+
FileMetadata(name="test.txt", size=0)
22+
23+
with pytest.raises(ValidationError):
24+
FileMetadata(name="test.txt", size=-1)
25+
26+
27+
def test_file_metadata_validation_invalid_name():
28+
"""Test that FileMetadata validates name field."""
29+
with pytest.raises(ValidationError):
30+
FileMetadata(name="", size=1024)
31+
32+
33+
def test_file_metadata_json_serialization():
34+
"""Test that FileMetadata can be serialized to and from JSON."""
35+
metadata = FileMetadata(
36+
name="test.txt",
37+
size=1024,
38+
content_type="text/plain"
39+
)
40+
41+
json_str = metadata.to_json()
42+
deserialized = FileMetadata.from_json(json_str)
43+
44+
assert deserialized.name == metadata.name
45+
assert deserialized.size == metadata.size
46+
assert deserialized.type == metadata.type
47+
48+
49+
def test_file_metadata_name_escaping():
50+
"""Test that FileMetadata properly escapes filenames during validation."""
51+
metadata = FileMetadata(
52+
name="file:name.txt",
53+
size=1024
54+
)
55+
assert metadata.name == "file name.txt"
56+
57+
58+
def test_file_metadata_size_conversion():
59+
"""Test that FileMetadata properly converts size strings to integers."""
60+
metadata = FileMetadata(
61+
name="test.txt",
62+
size="1024"
63+
)
64+
assert metadata.size == 1024
65+
assert isinstance(metadata.size, int)
66+
67+
68+
def test_file_metadata_size_human_readable():
69+
"""Test that FileMetadata properly formats sizes using ByteSize's human_readable method."""
70+
metadata = FileMetadata(
71+
name="test.txt",
72+
size=1024
73+
)
74+
assert metadata.size.human_readable() == "1.0KiB"
75+
76+
metadata = FileMetadata(
77+
name="test.txt",
78+
size=1048576
79+
)
80+
assert metadata.size.human_readable() == "1.0MiB"

views/http.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from starlette.background import BackgroundTask
66
from fastapi.exceptions import HTTPException
77
from fastapi.responses import StreamingResponse, PlainTextResponse
8+
from pydantic import ValidationError
89

910
from lib.logging import get_logger
1011
from lib.callbacks import raise_http_exception
@@ -34,8 +35,8 @@ async def http_upload(request: Request, uid: str, filename: str):
3435
except KeyError as e:
3536
log.error("△ Cannot decode file metadata from HTTP headers.", exc_info=e)
3637
raise HTTPException(status_code=400, detail="Cannot decode file metadata from HTTP headers.")
37-
except ValueError as e:
38-
log.error("△ Invalid file size.", exc_info=e)
38+
except ValidationError as e:
39+
log.error("△ Invalid file metadata.", exc_info=e)
3940
raise HTTPException(status_code=400, detail="Invalid file metadata.")
4041

4142
if file.size > 1024**3:
@@ -48,7 +49,7 @@ async def http_upload(request: Request, uid: str, filename: str):
4849
except KeyError as e:
4950
log.warning("△ Transfer ID is already used.")
5051
raise HTTPException(status_code=409, detail="Transfer ID is already used.")
51-
except (ValueError, TypeError) as e:
52+
except (TypeError, ValidationError) as e:
5253
log.error("△ Invalid transfer ID or file metadata.", exc_info=e)
5354
raise HTTPException(status_code=400, detail="Invalid transfer ID or file metadata.")
5455

@@ -90,7 +91,7 @@ async def http_download(request: Request, uid: str):
9091
transfer = await FileTransfer.get(uid)
9192
except KeyError:
9293
raise HTTPException(status_code=404, detail="Transfer not found.")
93-
except (ValueError, TypeError) as e:
94+
except (TypeError, ValidationError) as e:
9495
log.error("▼ Invalid transfer ID.", exc_info=e)
9596
raise HTTPException(status_code=400, detail="Invalid transfer ID.")
9697
else:

0 commit comments

Comments
 (0)