Skip to content

Commit c31b05e

Browse files
authored
Merge pull request #597 from onekey-sec/landlock
Experimental Landlock based sandboxing
2 parents 495e351 + f882f70 commit c31b05e

File tree

8 files changed

+305
-77
lines changed

8 files changed

+305
-77
lines changed

tests/test_cli.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from pathlib import Path
2-
from typing import Iterable, List, Optional, Type
2+
from typing import Iterable, List, Optional, Tuple, Type
33
from unittest import mock
44

55
import pytest
@@ -13,9 +13,11 @@
1313
from unblob.processing import (
1414
DEFAULT_DEPTH,
1515
DEFAULT_PROCESS_NUM,
16+
DEFAULT_SKIP_EXTENSION,
1617
DEFAULT_SKIP_MAGIC,
1718
ExtractionConfig,
1819
)
20+
from unblob.testing import is_sandbox_available
1921
from unblob.ui import (
2022
NullProgressReporter,
2123
ProgressReporter,
@@ -310,16 +312,16 @@ def test_keep_extracted_chunks(
310312

311313

312314
@pytest.mark.parametrize(
313-
"skip_extension, extracted_files_count",
315+
"skip_extension, expected_skip_extensions",
314316
[
315-
pytest.param([], 5, id="skip-extension-empty"),
316-
pytest.param([""], 5, id="skip-zip-extension-empty-suffix"),
317-
pytest.param([".zip"], 0, id="skip-extension-zip"),
318-
pytest.param([".rlib"], 5, id="skip-extension-rlib"),
317+
pytest.param((), DEFAULT_SKIP_EXTENSION, id="skip-extension-empty"),
318+
pytest.param(("",), ("",), id="skip-zip-extension-empty-suffix"),
319+
pytest.param((".zip",), (".zip",), id="skip-extension-zip"),
320+
pytest.param((".rlib",), (".rlib",), id="skip-extension-rlib"),
319321
],
320322
)
321323
def test_skip_extension(
322-
skip_extension: List[str], extracted_files_count: int, tmp_path: Path
324+
skip_extension: List[str], expected_skip_extensions: Tuple[str, ...], tmp_path: Path
323325
):
324326
runner = CliRunner()
325327
in_path = (
@@ -335,8 +337,12 @@ def test_skip_extension(
335337
for suffix in skip_extension:
336338
args += ["--skip-extension", suffix]
337339
params = [*args, "--extract-dir", str(tmp_path), str(in_path)]
338-
result = runner.invoke(unblob.cli.cli, params)
339-
assert extracted_files_count == len(list(tmp_path.rglob("*")))
340+
process_file_mock = mock.MagicMock()
341+
with mock.patch.object(unblob.cli, "process_file", process_file_mock):
342+
result = runner.invoke(unblob.cli.cli, params)
343+
assert (
344+
process_file_mock.call_args.args[0].skip_extension == expected_skip_extensions
345+
)
340346
assert result.exit_code == 0
341347

342348

@@ -420,3 +426,29 @@ def test_clear_skip_magics(
420426
assert sorted(process_file_mock.call_args.args[0].skip_magic) == sorted(
421427
skip_magic
422428
), fail_message
429+
430+
431+
@pytest.mark.skipif(
    not is_sandbox_available(), reason="Sandboxing is only available on Linux"
)
def test_sandbox_escape(tmp_path: Path):
    """Check that a write outside the extract directory fails the CLI run.

    ``process_file`` is replaced by a mock whose side effect writes to a file
    that is neither the input file nor located under ``--extract-dir``.  The
    CLI invocation is expected to end with a non-zero exit code and a
    ``PermissionError``, after calling the mock exactly once.
    """
    input_file = tmp_path / "input"
    input_file.touch()
    cli_args = ["--extract-dir", str(tmp_path / "extract-dir"), str(input_file)]

    # Sibling of the input file, outside of the extract directory.
    outside_file = tmp_path / "unrelated"

    def write_outside(*_args, **_kwargs):
        return outside_file.write_text("sandbox escape")

    fake_process_file = mock.MagicMock(side_effect=write_outside)
    cli_runner = CliRunner()
    with mock.patch.object(unblob.cli, "process_file", fake_process_file):
        outcome = cli_runner.invoke(unblob.cli.cli, cli_args)

    assert outcome.exit_code != 0
    assert isinstance(outcome.exception, PermissionError)
    fake_process_file.assert_called_once()

tests/test_sandbox.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
5+
from unblob.processing import ExtractionConfig
6+
from unblob.sandbox import Sandbox
7+
from unblob.testing import is_sandbox_available
8+
9+
pytestmark = pytest.mark.skipif(
10+
not is_sandbox_available(), reason="Sandboxing only works on Linux"
11+
)
12+
13+
14+
@pytest.fixture
def log_path(tmp_path):
    """Return the path ``unblob.log`` inside the per-test temporary directory."""
    return tmp_path.joinpath("unblob.log")
17+
18+
19+
@pytest.fixture
def extraction_config(extraction_config, tmp_path):
    """Redirect the inherited ``extraction_config`` to a root below ``tmp_path``.

    The extract root becomes ``<tmp_path>/extract/root``.  Only its parent
    directory is created here; the root itself is left for the tests to create.
    """
    new_root = tmp_path.joinpath("extract", "root")
    extraction_config.extract_root = new_root
    # The parent has to exist up front
    # (presumably a requirement of the sandbox set-up -- confirm in unblob/sandbox.py).
    extraction_config.extract_root.parent.mkdir()
    return extraction_config
25+
26+
27+
@pytest.fixture
def sandbox(extraction_config: ExtractionConfig, log_path: Path):
    """Build a ``Sandbox`` from the test extraction config and log path.

    The third constructor argument is passed as ``None``.
    """
    third_argument = None
    return Sandbox(extraction_config, log_path, third_argument)
30+
31+
32+
def test_necessary_resources_can_be_created_in_sandbox(
    sandbox: Sandbox, extraction_config: ExtractionConfig, log_path: Path
):
    """Run the filesystem operations the sandbox is expected to permit.

    Through ``sandbox.run`` the test creates the extract root, a nested
    directory below it and a file in that directory, then writes to the file
    and to the log file.  The test passes when none of these calls raises.
    """
    extract_root = extraction_config.extract_root
    nested_dir = extract_root.joinpath("path", "to", "dir")
    nested_file = nested_dir / "file"

    # Directories: the root itself, then a deeper tree below it.
    sandbox.run(extract_root.mkdir, parents=True)
    sandbox.run(nested_dir.mkdir, parents=True)

    # A regular file inside the extract root: create, then write.
    sandbox.run(nested_file.touch)
    sandbox.run(nested_file.write_text, "file content")

    # The log file is created outside of the sandbox first (original note:
    # "log-file is already opened"); only the write goes through the sandbox.
    log_path.touch()
    sandbox.run(log_path.write_text, "log line")
47+
48+
49+
def test_access_outside_sandbox_is_not_possible(sandbox: Sandbox, tmp_path: Path):
    """Check that creating paths outside the extract root raises ``PermissionError``.

    Both targets live directly under ``tmp_path``, not under the extract root
    or the log path used to build the sandbox.
    """
    forbidden_dir = tmp_path.joinpath("unrelated", "path")
    forbidden_file = tmp_path / "unrelated-file"

    # (callable, keyword arguments) pairs, each expected to be rejected.
    rejected_operations = (
        (forbidden_dir.mkdir, {"parents": True}),
        (forbidden_file.touch, {}),
    )
    for operation, kwargs in rejected_operations:
        with pytest.raises(PermissionError):
            sandbox.run(operation, **kwargs)

unblob/cli.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ExtractionConfig,
3434
process_file,
3535
)
36+
from .sandbox import Sandbox
3637
from .ui import NullProgressReporter, RichConsoleProgressReporter
3738

3839
logger = get_logger()
@@ -321,7 +322,8 @@ def cli(
321322
)
322323

323324
logger.info("Start processing file", file=file)
324-
process_results = process_file(config, file, report_file)
325+
sandbox = Sandbox(config, log_path, report_file)
326+
process_results = sandbox.run(process_file, config, file, report_file)
325327
if verbose == 0:
326328
if skip_extraction:
327329
print_scan_report(process_results)

unblob/pool.py

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
import abc
2+
import contextlib
23
import multiprocessing as mp
34
import os
45
import queue
6+
import signal
57
import sys
68
import threading
79
from multiprocessing.queues import JoinableQueue
8-
from typing import Any, Callable, Union
10+
from typing import Any, Callable, Set, Union
911

1012
from .logging import multiprocessing_breakpoint
1113

1214
mp.set_start_method("fork")
1315

1416

1517
class PoolBase(abc.ABC):
18+
def __init__(self):
19+
with pools_lock:
20+
pools.add(self)
21+
1622
@abc.abstractmethod
1723
def submit(self, args):
1824
pass
@@ -24,15 +30,20 @@ def process_until_done(self):
2430
def start(self):
2531
pass
2632

27-
def close(self):
28-
pass
33+
def close(self, *, immediate=False):  # noqa: ARG002
    """Remove this pool from the module-level ``pools`` registry.

    ``immediate`` is part of the shared ``close`` signature and is not used
    by this base implementation.

    Calling ``close`` more than once is harmless: a pool that is no longer
    registered is ignored.
    """
    with pools_lock:
        # discard() instead of remove(): close() can be reached twice for the
        # same pool -- once from the _on_terminate signal handler and again
        # from __exit__ -- and the second call must not raise KeyError.
        pools.discard(self)
2936

3037
def __enter__(self):
3138
self.start()
3239
return self
3340

34-
def __exit__(self, *args):
35-
self.close()
41+
def __exit__(self, exc_type, _exc_value, _tb):
42+
self.close(immediate=exc_type is not None)
43+
44+
45+
pools_lock = threading.Lock()
46+
pools: Set[PoolBase] = set()
3647

3748

3849
class Queue(JoinableQueue):
@@ -53,9 +64,15 @@ class _Sentinel:
5364

5465

5566
def _worker_process(handler, input_, output):
56-
# Creates a new process group, making sure no signals are propagated from the main process to the worker processes.
67+
# Creates a new process group, making sure no signals are
68+
# propagated from the main process to the worker processes.
5769
os.setpgrp()
5870

71+
# Restore default signal handlers, otherwise workers would inherit
72+
# them from main process
73+
signal.signal(signal.SIGTERM, signal.SIG_DFL)
74+
signal.signal(signal.SIGINT, signal.SIG_DFL)
75+
5976
sys.breakpointhook = multiprocessing_breakpoint
6077
while (args := input_.get()) is not _SENTINEL:
6178
result = handler(args)
@@ -71,11 +88,14 @@ def __init__(
7188
*,
7289
result_callback: Callable[["MultiPool", Any], Any],
7390
):
91+
super().__init__()
7492
if process_num <= 0:
7593
raise ValueError("At process_num must be greater than 0")
7694

95+
self._running = False
7796
self._result_callback = result_callback
7897
self._input = Queue(ctx=mp.get_context())
98+
self._input.cancel_join_thread()
7999
self._output = mp.SimpleQueue()
80100
self._procs = [
81101
mp.Process(
@@ -87,14 +107,32 @@ def __init__(
87107
self._tid = threading.get_native_id()
88108

89109
def start(self):
110+
self._running = True
90111
for p in self._procs:
91112
p.start()
92113

93-
def close(self):
94-
self._clear_input_queue()
95-
self._request_workers_to_quit()
96-
self._clear_output_queue()
114+
def close(self, *, immediate=False):
    """Shut the worker processes down and unregister the pool.

    Does nothing when the pool is not running, so repeated calls are safe.

    With ``immediate=False`` the input queue is cleared, the workers are
    asked to quit and the output queue is cleared.  With ``immediate=True``
    the workers are terminated instead.  In both cases the method then waits
    for the workers to quit and hands over to the base class ``close``.
    """
    if not self._running:
        return
    # Flip the flag first so that a re-entrant close() returns early above.
    self._running = False

    if not immediate:
        self._clear_input_queue()
        self._request_workers_to_quit()
        self._clear_output_queue()
    else:
        self._terminate_workers()

    self._wait_for_workers_to_quit()
    super().close(immediate=immediate)
128+
129+
def _terminate_workers(self):
    """Terminate every worker process and close the pool's queues."""
    for worker in self._procs:
        worker.terminate()

    self._input.close()

    # The output queue is only closed on Python 3.9 and newer; older
    # versions are skipped by this check.
    if sys.version_info < (3, 9):
        return
    self._output.close()
98136

99137
def _clear_input_queue(self):
100138
try:
@@ -129,14 +167,16 @@ def submit(self, args):
129167
self._input.put(args)
130168

131169
def process_until_done(self):
132-
while not self._input.is_empty():
133-
result = self._output.get()
134-
self._result_callback(self, result)
135-
self._input.task_done()
170+
with contextlib.suppress(EOFError):
171+
while not self._input.is_empty():
172+
result = self._output.get()
173+
self._result_callback(self, result)
174+
self._input.task_done()
136175

137176

138177
class SinglePool(PoolBase):
139178
def __init__(self, handler, *, result_callback):
179+
super().__init__()
140180
self._handler = handler
141181
self._result_callback = result_callback
142182

@@ -157,3 +197,19 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP
157197
handler=handler,
158198
result_callback=result_callback,
159199
)
200+
201+
202+
orig_signal_handlers = {}
203+
204+
205+
def _on_terminate(signum, frame):
    """Force-close every registered pool, then chain to the previous handler.

    ``signum`` and ``frame`` are the arguments Python passes to a signal
    handler; they are forwarded unchanged to the handler recorded in
    ``orig_signal_handlers`` for this signal, if that entry is callable.
    """
    # Iterate over a copy: closing a pool removes it from ``pools``.
    for pool in tuple(pools):
        pool.close(immediate=True)

    previous_handler = orig_signal_handlers[signum]
    # Non-callable entries (e.g. SIG_DFL / SIG_IGN markers) are not invoked.
    if callable(previous_handler):
        previous_handler(signum, frame)
212+
213+
214+
orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _on_terminate)
215+
orig_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, _on_terminate)

unblob/processing.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
StatReport,
4646
UnknownError,
4747
)
48-
from .signals import terminate_gracefully
4948
from .ui import NullProgressReporter, ProgressReporter
5049

5150
logger = get_logger()
@@ -118,7 +117,6 @@ def get_carve_dir_for(self, path: Path) -> Path:
118117
return self._get_output_path(path.with_name(path.name + self.carve_suffix))
119118

120119

121-
@terminate_gracefully
122120
def process_file(
123121
config: ExtractionConfig, input_path: Path, report_file: Optional[Path] = None
124122
) -> ProcessResult:

0 commit comments

Comments
 (0)