Commit 3c10b34

Merge pull request #183 from python-discord/enforce-filesize-limits
Enforce filesize limits
2 parents: 3888578 + 16b1a13

4 files changed (+76, -7 lines)

requirements/pip-tools.in

Lines changed: 2 additions & 2 deletions

@@ -2,5 +2,5 @@
 -c lint.pip
 -c requirements.pip

-# Minimum version which supports pip>=22.1
-pip-tools>=6.6.1
+# Minimum version which supports pip>=23.2
+pip-tools>=7.0.0

requirements/pip-tools.pip

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ click==8.1.3
     # via pip-tools
 packaging==23.0
     # via build
-pip-tools==6.12.3
+pip-tools==7.3.0
     # via -r requirements/pip-tools.in
 pyproject-hooks==1.0.0
     # via build

snekbox/memfs.py

Lines changed: 16 additions & 2 deletions

@@ -144,6 +144,7 @@ def files(
         """
         start_time = time.monotonic()
         count = 0
+        total_size = 0
         files = glob.iglob(pattern, root_dir=str(self.output), recursive=True, include_hidden=False)
         for file in (Path(self.output, f) for f in files):
             if timeout and (time.monotonic() - start_time) > timeout:
@@ -152,17 +153,30 @@ def files(
             if not file.is_file():
                 continue

+            # file.is_file allows file to be a regular file OR a symlink pointing to a regular file.
+            # It is important that we follow symlinks here, so when we check st_size later it is the
+            # size of the underlying file rather than of the symlink.
+            stat = file.stat(follow_symlinks=True)
+
             if exclude_files and (orig_time := exclude_files.get(file)):
-                new_time = file.stat().st_mtime
+                new_time = stat.st_mtime
                 log.info(f"Checking {file.name} ({orig_time=}, {new_time=})")
-                if file.stat().st_mtime == orig_time:
+                if stat.st_mtime == orig_time:
                     log.info(f"Skipping {file.name!r} as it has not been modified")
                     continue

             if count > limit:
                 log.info(f"Max attachments {limit} reached, skipping remaining files")
                 break

+            # Due to sparse files and links the total size could end up being greater
+            # than the size limit of the tmpfs. Limit the total size to be read to
+            # prevent high memory usage / OOM when reading files.
+            total_size += stat.st_size
+            if total_size > self.instance_size:
+                log.info(f"Max file size {self.instance_size} reached, skipping remaining files")
+                break
+
             count += 1
             log.info(f"Found valid file for upload {file.name!r}")
             yield FileAttachment.from_path(file, relative_to=self.output)
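
The guard added above is the heart of the change: each candidate file is stat'ed once with follow_symlinks=True, so a symlink is charged the size of its target, and iteration stops as soon as the accumulated st_size exceeds the tmpfs instance size. A minimal standalone sketch of the same pattern is shown below; the name iter_files_within_budget and the max_total_size parameter are illustrative and not part of snekbox.

from pathlib import Path
from typing import Iterator


def iter_files_within_budget(root: Path, max_total_size: int) -> Iterator[Path]:
    """Yield regular files under root until their combined apparent size exceeds the budget.

    Illustrative sketch only; names and structure are not taken from snekbox itself.
    """
    total_size = 0
    for file in sorted(root.rglob("*")):
        # is_file() accepts both regular files and symlinks that point to regular files.
        if not file.is_file():
            continue
        # Follow symlinks so st_size reflects the target file, not the link inode.
        total_size += file.stat(follow_symlinks=True).st_size
        if total_size > max_total_size:
            # Stop before yielding the file that blows the budget, mirroring the break above.
            break
        yield file

As in the snekbox loop, the file that pushes the running total over the limit is itself skipped, so the caller never reads more than the configured instance size in total.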

tests/test_nsjail.py

Lines changed: 57 additions & 2 deletions

@@ -218,8 +218,9 @@ def test_file_parsing_timeout(self):
                 os.symlink("file", f"file{i}")
             """
         ).strip()
-
-        nsjail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=1)
+        # A value higher than the actual memory needed is used to avoid the limit
+        # on total file size being reached before the timeout when reading.
+        nsjail = NsJail(memfs_instance_size=512 * Size.MiB, files_timeout=1)
         result = nsjail.python3(["-c", code])
         self.assertEqual(result.returncode, None)
         self.assertEqual(
@@ -250,6 +251,60 @@ def test_file_parsing_depth_limit(self):
         )
         self.assertEqual(result.stderr, None)

+    def test_file_parsing_size_limit_sparse_files(self):
+        tmpfs_size = 8 * Size.MiB
+        code = dedent(
+            f"""
+            import os
+            with open("test.txt", "w") as f:
+                os.truncate(f.fileno(), {tmpfs_size // 2 + 1})
+
+            with open("test2.txt", "w") as f:
+                os.truncate(f.fileno(), {tmpfs_size // 2 + 1})
+            """
+        )
+        nsjail = NsJail(memfs_instance_size=tmpfs_size, files_timeout=5)
+        result = nsjail.python3(["-c", code])
+        self.assertEqual(result.returncode, 0)
+        self.assertEqual(len(result.files), 1)
+
+    def test_file_parsing_size_limit_sparse_files_large(self):
+        tmpfs_size = 8 * Size.MiB
+        code = dedent(
+            f"""
+            import os
+            with open("test.txt", "w") as f:
+                # Use a very large value to ensure the test fails if the
+                # file is read even if it would have been discarded later.
+                os.truncate(f.fileno(), {1024 * Size.TiB})
+            """
+        )
+        nsjail = NsJail(memfs_instance_size=tmpfs_size, files_timeout=5)
+        result = nsjail.python3(["-c", code])
+        self.assertEqual(result.returncode, 0)
+        self.assertEqual(len(result.files), 0)
+
+    def test_file_parsing_size_limit_symlinks(self):
+        tmpfs_size = 8 * Size.MiB
+        code = dedent(
+            f"""
+            import os
+            data = "a" * 1024
+            size = {tmpfs_size // 8}
+
+            with open("file", "w") as f:
+                for _ in range(size // 1024):
+                    f.write(data)
+
+            for i in range(20):
+                os.symlink("file", f"file{{i}}")
+            """
+        )
+        nsjail = NsJail(memfs_instance_size=tmpfs_size, files_timeout=5)
+        result = nsjail.python3(["-c", code])
+        self.assertEqual(result.returncode, 0)
+        self.assertEqual(len(result.files), 8)
+
     def test_file_write_error(self):
         """Test errors during file write."""
         result = self.nsjail.python3(
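
The sparse-file tests above work because os.truncate can grow a file's reported size without writing any data: st_size jumps to the requested length while the number of allocated blocks stays near zero, so the apparent size can exceed the tmpfs capacity (and, in the 1024 TiB case, any reasonable read buffer). A small, self-contained demonstration of that gap on a typical Linux filesystem might look like the following; it is not part of the snekbox test suite.

import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sparse.bin")
    with open(path, "w") as f:
        # Extend the file to 1 GiB without writing a single byte of data.
        os.truncate(f.fileno(), 1024 ** 3)

    st = os.stat(path)
    print(st.st_size)          # 1073741824 -- the "apparent" size a naive reader would try to load
    print(st.st_blocks * 512)  # close to 0 on Linux -- bytes actually allocated on disk

Reading such a file back still materialises every zero byte in memory, which is exactly the high memory usage the new total_size check in memfs.py prevents.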
