Skip to content

Commit 46dd830

Browse files
committed
Add timeout to MemFS.files and MemFS.files_list for more cooperative cancellation
1 parent 2abd6cd commit 46dd830

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

snekbox/memfs.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Memory filesystem for snekbox."""
22
from __future__ import annotations
33

4+
import glob
45
import logging
6+
import time
57
import warnings
68
import weakref
79
from collections.abc import Generator
@@ -125,6 +127,7 @@ def files(
125127
limit: int,
126128
pattern: str = "**/*",
127129
exclude_files: dict[Path, float] | None = None,
130+
timeout: float | None = None,
128131
) -> Generator[FileAttachment, None, None]:
129132
"""
130133
Yields FileAttachments for files found in the output directory.
@@ -135,12 +138,17 @@ def files(
135138
exclude_files: A dict of Paths and last modified times.
136139
Files will be excluded if their last modified time
137140
is equal to the provided value.
141+
timeout: The maximum time for the file parsing. If exceeded,
142+
a TimeoutError will be raised.
138143
"""
139-
count = 0
140-
for file in self.output.rglob(pattern):
141-
# Ignore hidden directories or files
142-
if any(part.startswith(".") for part in file.parts):
143-
log.info(f"Skipping hidden path {file!s}")
144+
start_time = time.monotonic()
145+
added = 0
146+
files = glob.iglob(pattern, root_dir=str(self.output), recursive=True, include_hidden=False)
147+
for file in (Path(self.output, f) for f in files):
148+
if timeout and (time.monotonic() - start_time) > timeout:
149+
raise TimeoutError("File parsing timeout exceeded in MemFS.files")
150+
151+
if not file.is_file():
144152
continue
145153

146154
if exclude_files and (orig_time := exclude_files.get(file)):
@@ -150,21 +158,21 @@ def files(
150158
log.info(f"Skipping {file.name!r} as it has not been modified")
151159
continue
152160

153-
if count > limit:
161+
if added > limit:
154162
log.info(f"Max attachments {limit} reached, skipping remaining files")
155163
break
156164

157-
if file.is_file():
158-
count += 1
159-
log.info(f"Found valid file for upload {file.name!r}")
160-
yield FileAttachment.from_path(file, relative_to=self.output)
165+
added += 1
166+
log.info(f"Found valid file for upload {file.name!r}")
167+
yield FileAttachment.from_path(file, relative_to=self.output)
161168

162169
def files_list(
163170
self,
164171
limit: int,
165172
pattern: str,
166173
exclude_files: dict[Path, float] | None = None,
167174
preload_dict: bool = False,
175+
timeout: float | None = None,
168176
) -> list[FileAttachment]:
169177
"""
170178
Return a sorted list of file paths within the output directory.
@@ -176,15 +184,20 @@ def files_list(
176184
Files will be excluded if their last modified time
177185
is equal to the provided value.
178186
preload_dict: Whether to preload as_dict property data.
187+
timeout: The maximum time for the file parsing. If exceeded,
188+
a TimeoutError will be raised.
179189
Returns:
180190
List of FileAttachments sorted lexically by path name.
181191
"""
192+
start_time = time.monotonic()
182193
res = sorted(
183194
self.files(limit=limit, pattern=pattern, exclude_files=exclude_files),
184195
key=lambda f: f.path,
185196
)
186197
if preload_dict:
187198
for file in res:
199+
if timeout and (time.monotonic() - start_time) > timeout:
200+
raise TimeoutError("File parsing timeout exceeded in MemFS.files_list")
188201
# Loads the cached property as attribute
189202
_ = file.as_dict
190203
return res

0 commit comments

Comments
 (0)