Skip to content

Commit 9804a10

Browse files
Merge pull request #173 from python-discord/file-scan-recursion-fix
Fix recursion error during file attachment parsing of deep nested paths
2 parents 9acc6f5 + 90910bd commit 9804a10

File tree

5 files changed

+117
-41
lines changed

5 files changed

+117
-41
lines changed

snekbox/memfs.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Memory filesystem for snekbox."""
22
from __future__ import annotations
33

4+
import glob
45
import logging
6+
import time
57
import warnings
68
import weakref
79
from collections.abc import Generator
@@ -125,6 +127,7 @@ def files(
125127
limit: int,
126128
pattern: str = "**/*",
127129
exclude_files: dict[Path, float] | None = None,
130+
timeout: float | None = None,
128131
) -> Generator[FileAttachment, None, None]:
129132
"""
130133
Yields FileAttachments for files found in the output directory.
@@ -135,12 +138,18 @@ def files(
135138
exclude_files: A dict of Paths and last modified times.
136139
Files will be excluded if their last modified time
137140
is equal to the provided value.
141+
timeout: Maximum time in seconds for file parsing.
142+
Raises:
143+
TimeoutError: If file parsing exceeds timeout.
138144
"""
145+
start_time = time.monotonic()
139146
count = 0
140-
for file in self.output.rglob(pattern):
141-
# Ignore hidden directories or files
142-
if any(part.startswith(".") for part in file.parts):
143-
log.info(f"Skipping hidden path {file!s}")
147+
files = glob.iglob(pattern, root_dir=str(self.output), recursive=True, include_hidden=False)
148+
for file in (Path(self.output, f) for f in files):
149+
if timeout and (time.monotonic() - start_time) > timeout:
150+
raise TimeoutError("File parsing timeout exceeded in MemFS.files")
151+
152+
if not file.is_file():
144153
continue
145154

146155
if exclude_files and (orig_time := exclude_files.get(file)):
@@ -154,17 +163,17 @@ def files(
154163
log.info(f"Max attachments {limit} reached, skipping remaining files")
155164
break
156165

157-
if file.is_file():
158-
count += 1
159-
log.info(f"Found valid file for upload {file.name!r}")
160-
yield FileAttachment.from_path(file, relative_to=self.output)
166+
count += 1
167+
log.info(f"Found valid file for upload {file.name!r}")
168+
yield FileAttachment.from_path(file, relative_to=self.output)
161169

162170
def files_list(
    self,
    limit: int,
    pattern: str,
    exclude_files: dict[Path, float] | None = None,
    preload_dict: bool = False,
    timeout: float | None = None,
) -> list[FileAttachment]:
    """
    Return a sorted list of file paths within the output directory.

    Args:
        limit: Passed through to `files` as its `limit` argument.
        pattern: Passed through to `files` as its `pattern` argument.
        exclude_files: A dict of Paths and last modified times.
            Files will be excluded if their last modified time
            is equal to the provided value.
        preload_dict: Whether to preload as_dict property data.
        timeout: Maximum time in seconds for file parsing. The same value
            is passed to `files`, and is also checked against the time
            elapsed since this call started while preloading.
            A falsy value (None or 0) disables the check.
    Returns:
        List of FileAttachments sorted lexically by path name.
    Raises:
        TimeoutError: If file parsing exceeds timeout.
    """
    start_time = time.monotonic()
    # Forward the timeout so the directory scan itself is bounded;
    # without it only the preload loop below honours the limit.
    res = sorted(
        self.files(
            limit=limit,
            pattern=pattern,
            exclude_files=exclude_files,
            timeout=timeout,
        ),
        key=lambda f: f.path,
    )
    if preload_dict:
        for file in res:
            if timeout and (time.monotonic() - start_time) > timeout:
                raise TimeoutError("File parsing timeout exceeded in MemFS.files_list")
            # Loads the cached property as attribute
            _ = file.as_dict
    return res

snekbox/nsjail.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from snekbox.memfs import MemFS
1616
from snekbox.process import EvalResult
1717
from snekbox.snekio import FileAttachment
18-
from snekbox.utils.timed import timed
18+
from snekbox.utils.timed import time_limit
1919

2020
__all__ = ("NsJail",)
2121

@@ -56,7 +56,7 @@ def __init__(
5656
memfs_home: str = "home",
5757
memfs_output: str = "home",
5858
files_limit: int | None = 100,
59-
files_timeout: float | None = 8,
59+
files_timeout: int | None = 5,
6060
files_pattern: str = "**/[!_]*",
6161
):
6262
"""
@@ -267,21 +267,32 @@ def python3(
267267

268268
# Parse attachments with time limit
269269
try:
270-
attachments = timed(
271-
MemFS.files_list,
272-
(fs, self.files_limit, self.files_pattern),
273-
{
274-
"preload_dict": True,
275-
"exclude_files": files_written,
276-
},
277-
timeout=self.files_timeout,
278-
)
270+
with time_limit(self.files_timeout):
271+
attachments = fs.files_list(
272+
limit=self.files_limit,
273+
pattern=self.files_pattern,
274+
preload_dict=True,
275+
exclude_files=files_written,
276+
timeout=self.files_timeout,
277+
)
279278
log.info(f"Found {len(attachments)} files.")
279+
except RecursionError:
280+
log.info("Recursion error while parsing attachments")
281+
return EvalResult(
282+
args,
283+
None,
284+
"FileParsingError: Exceeded directory depth limit while parsing attachments",
285+
)
280286
except TimeoutError as e:
281287
log.info(f"Exceeded time limit while parsing attachments: {e}")
282288
return EvalResult(
283289
args, None, "TimeoutError: Exceeded time limit while parsing attachments"
284290
)
291+
except Exception as e:
292+
log.exception(f"Unexpected {type(e).__name__} while parse attachments", exc_info=e)
293+
return EvalResult(
294+
args, None, "FileParsingError: Unknown error while parsing attachments"
295+
)
285296

286297
log_lines = nsj_log.read().decode("utf-8").splitlines()
287298
if not log_lines and returncode == 255:

snekbox/utils/timed.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,34 @@
11
"""Calling functions with time limits."""
2-
import multiprocessing
3-
from collections.abc import Callable, Iterable, Mapping
4-
from typing import Any, TypeVar
2+
import signal
3+
from collections.abc import Generator
4+
from contextlib import contextmanager
5+
from typing import TypeVar
56

67
_T = TypeVar("_T")
78
_V = TypeVar("_V")
89

9-
__all__ = ("timed",)
10+
__all__ = ("time_limit",)
1011

1112

12-
def timed(
13-
func: Callable[[_T], _V],
14-
args: Iterable = (),
15-
kwds: Mapping[str, Any] | None = None,
16-
timeout: float | None = None,
17-
) -> _V:
13+
@contextmanager
14+
def time_limit(timeout: int | None = None) -> Generator[None, None, None]:
1815
"""
19-
Call a function with a time limit.
16+
Decorator to call a function with a time limit.
2017
2118
Args:
22-
func: Function to call.
23-
args: Arguments for function.
24-
kwds: Keyword arguments for function.
2519
timeout: Timeout limit in seconds.
2620
2721
Raises:
2822
TimeoutError: If the function call takes longer than `timeout` seconds.
2923
"""
30-
if kwds is None:
31-
kwds = {}
32-
with multiprocessing.Pool(1) as pool:
33-
result = pool.apply_async(func, args, kwds)
34-
try:
35-
return result.get(timeout)
36-
except multiprocessing.TimeoutError as e:
37-
raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e
24+
25+
def signal_handler(_signum, _frame):
26+
raise TimeoutError(f"time_limit call timed out after {timeout} seconds.")
27+
28+
signal.signal(signal.SIGALRM, signal_handler)
29+
signal.alarm(timeout)
30+
31+
try:
32+
yield
33+
finally:
34+
signal.alarm(0)

tests/test_nsjail.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,29 @@ def test_file_parsing_timeout(self):
227227
)
228228
self.assertEqual(result.stderr, None)
229229

230+
def test_file_parsing_depth_limit(self):
    """Test that a 1000-level deep directory tree yields the depth limit error."""
    # The snippet builds a/, a/a/, a/a/a/, ... and writes one file at the bottom.
    snippet = dedent(
        """
        import os

        x = ""
        for _ in range(1000):
            x += "a/"
            os.mkdir(x)

        open(f"{x}test.txt", "w").write("test")
        """
    ).strip()
    expected_stdout = "FileParsingError: Exceeded directory depth limit while parsing attachments"

    jail = NsJail(memfs_instance_size=32 * Size.MiB, files_timeout=5)
    outcome = jail.python3(["-c", snippet])

    self.assertEqual(outcome.returncode, None)
    self.assertEqual(outcome.stdout, expected_stdout)
    self.assertEqual(outcome.stderr, None)
252+
230253
def test_file_write_error(self):
231254
"""Test errors during file write."""
232255
result = self.nsjail.python3(

tests/test_timed.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import math
2+
import time
3+
from unittest import TestCase
4+
5+
from snekbox.utils.timed import time_limit
6+
7+
8+
class TimedTests(TestCase):
    """Checks that `time_limit` interrupts code running past its limit."""

    def test_sleep(self):
        """Test that a sleep can be interrupted."""
        _finished = False
        began = time.perf_counter()
        with self.assertRaises(TimeoutError), time_limit(1):
            time.sleep(2)
            # Only reached if the sleep was not interrupted.
            _finished = True
        elapsed = time.perf_counter() - began
        self.assertLess(elapsed, 2)
        self.assertFalse(_finished)

    def test_iter(self):
        """Test that a long-running built-in function can be interrupted."""
        _result = 0
        began = time.perf_counter()
        with self.assertRaises(TimeoutError), time_limit(1):
            # The assignment only happens if the call is allowed to finish.
            _result = math.factorial(2**30)
        elapsed = time.perf_counter() - began
        self.assertEqual(_result, 0)
        self.assertLess(elapsed, 2)

0 commit comments

Comments
 (0)