Skip to content

Commit 7ca3917

Browse files
committed
Add SIGALRM based time limit
1 parent 6601b36 commit 7ca3917

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

snekbox/nsjail.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from snekbox.memfs import MemFS
1616
from snekbox.process import EvalResult
1717
from snekbox.snekio import FileAttachment
18-
from snekbox.utils.timed import timed
18+
from snekbox.utils.timed import time_limit
1919

2020
__all__ = ("NsJail",)
2121

@@ -267,16 +267,14 @@ def python3(
267267

268268
# Parse attachments with time limit
269269
try:
270-
attachments = timed(
271-
MemFS.files_list,
272-
(fs, self.files_limit, self.files_pattern),
273-
{
274-
"preload_dict": True,
275-
"exclude_files": files_written,
276-
"timeout": self.files_timeout,
277-
},
278-
timeout=self.files_timeout,
279-
)
270+
with time_limit(self.files_timeout):
271+
attachments = fs.files_list(
272+
limit=self.files_limit,
273+
pattern=self.files_pattern,
274+
preload_dict=True,
275+
exclude_files=files_written,
276+
timeout=self.files_timeout,
277+
)
280278
log.info(f"Found {len(attachments)} files.")
281279
except RecursionError:
282280
log.info("Recursion error while parsing attachments")
@@ -290,6 +288,11 @@ def python3(
290288
return EvalResult(
291289
args, None, "TimeoutError: Exceeded time limit while parsing attachments"
292290
)
291+
except Exception as e:
292+
log.error(f"Unexpected {type(e).__name__} while parse attachments: {e}")
293+
return EvalResult(
294+
args, None, "FileParsingError: Unknown error while parsing attachments"
295+
)
293296

294297
log_lines = nsj_log.read().decode("utf-8").splitlines()
295298
if not log_lines and returncode == 255:

snekbox/utils/timed.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""Calling functions with time limits."""
22
import multiprocessing
3-
from collections.abc import Callable, Iterable, Mapping
3+
import signal
4+
from collections.abc import Callable, Generator, Iterable, Mapping
5+
from contextlib import contextmanager
46
from typing import Any, TypeVar
57

68
_T = TypeVar("_T")
79
_V = TypeVar("_V")
810

9-
__all__ = ("timed",)
11+
__all__ = ("timed", "time_limit")
1012

1113

1214
def timed(
@@ -35,3 +37,27 @@ def timed(
3537
return result.get(timeout)
3638
except multiprocessing.TimeoutError as e:
3739
raise TimeoutError(f"Call to {func.__name__} timed out after {timeout} seconds.") from e
40+
41+
42+
@contextmanager
43+
def time_limit(timeout: int | None = None) -> Generator[None, None, None]:
44+
"""
45+
Decorator to call a function with a time limit. Uses SIGALRM, requires a UNIX system.
46+
47+
Args:
48+
timeout: Timeout limit in seconds.
49+
50+
Raises:
51+
TimeoutError: If the function call takes longer than `timeout` seconds.
52+
"""
53+
54+
def signal_handler(signum, frame):
55+
raise TimeoutError(f"time_limit call timed out after {timeout} seconds.")
56+
57+
signal.signal(signal.SIGALRM, signal_handler)
58+
signal.alarm(timeout)
59+
60+
try:
61+
yield
62+
finally:
63+
signal.alarm(0)

0 commit comments

Comments
 (0)