|
7 | 7 | import os |
8 | 8 | import signal |
9 | 9 | from dataclasses import dataclass, field |
10 | | -from typing import Any, Callable |
| 10 | +from typing import Any, Callable, TYPE_CHECKING |
11 | 11 |
|
12 | 12 | from .exceptions import SandboxError |
13 | 13 | from ._context import SandboxContext |
14 | 14 | from .policy import Policy |
15 | 15 |
|
| 16 | +if TYPE_CHECKING: |
| 17 | + from .sandbox import Sandbox |
| 18 | + |
16 | 19 |
|
17 | 20 | @dataclass |
18 | 21 | class Result: |
@@ -217,3 +220,331 @@ def _drain_and_close(fd: int) -> None: |
217 | 220 | os.close(fd) |
218 | 221 | except OSError: |
219 | 222 | pass |
| 223 | + |
| 224 | + |
| 225 | +# --- Stage / Pipeline --- |
| 226 | + |
| 227 | + |
| 228 | +class Stage: |
| 229 | + """A sandbox bound to a command, not yet running. |
| 230 | +
|
| 231 | + Created by :meth:`Sandbox.cmd`. Stages can be chained into a |
| 232 | + pipeline with the ``|`` operator:: |
| 233 | +
|
| 234 | + result = ( |
| 235 | + Sandbox(policy_a).cmd(["producer"]) |
| 236 | + | Sandbox(policy_b).cmd(["consumer"]) |
| 237 | + ).run() |
| 238 | + """ |
| 239 | + |
| 240 | + __slots__ = ("sandbox", "args") |
| 241 | + |
| 242 | + def __init__(self, sandbox: Sandbox, args: list[str]): |
| 243 | + self.sandbox = sandbox |
| 244 | + self.args = args |
| 245 | + |
| 246 | + def __or__(self, other: Stage | Pipeline) -> Pipeline: |
| 247 | + if isinstance(other, Pipeline): |
| 248 | + return Pipeline([self] + other.stages) |
| 249 | + if isinstance(other, Stage): |
| 250 | + return Pipeline([self, other]) |
| 251 | + return NotImplemented |
| 252 | + |
| 253 | + def run(self, *, timeout: float | None = None) -> Result: |
| 254 | + """Run this single stage. Equivalent to ``sandbox.run(args)``.""" |
| 255 | + return self.sandbox.run(self.args, timeout=timeout) |
| 256 | + |
| 257 | + |
| 258 | +class Pipeline: |
| 259 | + """A chain of stages connected by pipes. |
| 260 | +
|
| 261 | + Each stage's stdout is wired to the next stage's stdin. The |
| 262 | + parent process never holds the inter-stage pipe data — it flows |
| 263 | + through kernel buffers only. |
| 264 | +
|
| 265 | + Created by ``stage_a | stage_b``. |
| 266 | + """ |
| 267 | + |
| 268 | + __slots__ = ("stages",) |
| 269 | + |
| 270 | + def __init__(self, stages: list[Stage]): |
| 271 | + if len(stages) < 2: |
| 272 | + raise ValueError("Pipeline requires at least 2 stages") |
| 273 | + self.stages = stages |
| 274 | + |
| 275 | + def __or__(self, other: Stage | Pipeline) -> Pipeline: |
| 276 | + if isinstance(other, Stage): |
| 277 | + return Pipeline(self.stages + [other]) |
| 278 | + if isinstance(other, Pipeline): |
| 279 | + return Pipeline(self.stages + other.stages) |
| 280 | + return NotImplemented |
| 281 | + |
| 282 | + def run( |
| 283 | + self, |
| 284 | + *, |
| 285 | + stdout: int | None = None, |
| 286 | + timeout: float | None = None, |
| 287 | + ) -> Result: |
| 288 | + """Execute the pipeline. |
| 289 | +
|
| 290 | + Each stage runs in its own sandbox. Stdout of stage N is piped |
| 291 | + to stdin of stage N+1. The parent never reads inter-stage data. |
| 292 | +
|
| 293 | + Args: |
| 294 | + stdout: File descriptor for the final stage's stdout. |
| 295 | + When set, the last stage writes directly to this fd |
| 296 | + and ``result.stdout`` is empty. Use ``sys.stdout.fileno()`` |
| 297 | + to send output to the terminal. |
| 298 | + timeout: Maximum seconds to wait for the entire pipeline. |
| 299 | +
|
| 300 | + Returns: |
| 301 | + Result from the last stage. ``stdout`` contains the last |
| 302 | + stage's output only when the *stdout* parameter is not set. |
| 303 | + """ |
| 304 | + return run_pipeline(self.stages, stdout=stdout, timeout=timeout) |
| 305 | + |
| 306 | + |
| 307 | +def run_pipeline( |
| 308 | + stages: list[Stage], |
| 309 | + *, |
| 310 | + stdout: int | None = None, |
| 311 | + timeout: float | None = None, |
| 312 | +) -> Result: |
| 313 | + """Execute a pipeline of sandboxed stages connected by pipes. |
| 314 | +
|
| 315 | + The parent creates pipes between adjacent stages and closes its |
| 316 | + copies immediately after fork. Data flows through kernel buffers; |
| 317 | + the parent never reads inter-stage data. |
| 318 | +
|
| 319 | + Args: |
| 320 | + stages: Ordered list of stages to execute. |
| 321 | + stdout: File descriptor for the final stage's stdout. |
| 322 | + timeout: Maximum seconds to wait for the entire pipeline. |
| 323 | +
|
| 324 | + Returns: |
| 325 | + Result from the last stage. |
| 326 | + """ |
| 327 | + n = len(stages) |
| 328 | + if n == 0: |
| 329 | + raise ValueError("Pipeline requires at least 1 stage") |
| 330 | + if n == 1: |
| 331 | + return stages[0].run(timeout=timeout) |
| 332 | + |
| 333 | + # Create inter-stage pipes: pipe[i] connects stage i → stage i+1 |
| 334 | + # Each pipe is (read_fd, write_fd) |
| 335 | + pipes = [os.pipe() for _ in range(n - 1)] |
| 336 | + |
| 337 | + # Create stderr pipe for last stage |
| 338 | + last_stderr_r, last_stderr_w = os.pipe() |
| 339 | + |
| 340 | + # Capture stdout of last stage (unless caller provided an fd) |
| 341 | + capture_stdout = stdout is None |
| 342 | + if capture_stdout: |
| 343 | + last_stdout_r, last_stdout_w = os.pipe() |
| 344 | + else: |
| 345 | + last_stdout_r, last_stdout_w = -1, -1 |
| 346 | + |
| 347 | + contexts = [None] * n # type: list |
| 348 | + opened_fds = [] # type: list # track fds we need to clean up on error |
| 349 | + |
| 350 | + try: |
| 351 | + for i, stage in enumerate(stages): |
| 352 | + # Determine this stage's stdin/stdout fds |
| 353 | + stdin_fd = pipes[i - 1][0] if i > 0 else -1 |
| 354 | + if i < n - 1: |
| 355 | + stdout_fd = pipes[i][1] |
| 356 | + elif capture_stdout: |
| 357 | + stdout_fd = last_stdout_w |
| 358 | + else: |
| 359 | + stdout_fd = os.dup(stdout) |
| 360 | + opened_fds.append(stdout_fd) |
| 361 | + stderr_fd = last_stderr_w if i == n - 1 else -1 |
| 362 | + |
| 363 | + # Build the target function for this stage. Capture fds |
| 364 | + # and pipe list by value to avoid closure issues. |
| 365 | + cmd = stage.args |
| 366 | + _stdin = stdin_fd |
| 367 | + _stdout = stdout_fd |
| 368 | + _stderr = stderr_fd |
| 369 | + _pipes = pipes |
| 370 | + _last_stderr_w = last_stderr_w |
| 371 | + _last_stdout_r = last_stdout_r |
| 372 | + _last_stdout_w = last_stdout_w |
| 373 | + |
| 374 | + def _make_target(cmd, _stdin, _stdout, _stderr, |
| 375 | + _pipes, _last_stderr_w, |
| 376 | + _last_stdout_r, _last_stdout_w): |
| 377 | + def _target(): |
| 378 | + # Close all pipe fds the child doesn't need |
| 379 | + for r, w in _pipes: |
| 380 | + if r != _stdin: |
| 381 | + os.close(r) |
| 382 | + if w != _stdout: |
| 383 | + os.close(w) |
| 384 | + if _last_stderr_w >= 0 and _last_stderr_w != _stderr: |
| 385 | + os.close(_last_stderr_w) |
| 386 | + if _last_stdout_r >= 0: |
| 387 | + os.close(_last_stdout_r) |
| 388 | + if _last_stdout_w >= 0 and _last_stdout_w != _stdout: |
| 389 | + os.close(_last_stdout_w) |
| 390 | + |
| 391 | + # Wire stdin |
| 392 | + if _stdin >= 0: |
| 393 | + os.dup2(_stdin, 0) |
| 394 | + os.close(_stdin) |
| 395 | + |
| 396 | + # Wire stdout |
| 397 | + os.dup2(_stdout, 1) |
| 398 | + if _stdout > 2: |
| 399 | + os.close(_stdout) |
| 400 | + |
| 401 | + # Wire stderr (last stage only) |
| 402 | + if _stderr >= 0: |
| 403 | + os.dup2(_stderr, 2) |
| 404 | + if _stderr > 2: |
| 405 | + os.close(_stderr) |
| 406 | + |
| 407 | + try: |
| 408 | + os.execvp(cmd[0], cmd) |
| 409 | + except OSError as e: |
| 410 | + os.write(2, f"exec failed: {e}\n".encode()) |
| 411 | + os._exit(127) |
| 412 | + |
| 413 | + return _target |
| 414 | + |
| 415 | + target = _make_target(cmd, _stdin, _stdout, _stderr, |
| 416 | + _pipes, _last_stderr_w, |
| 417 | + _last_stdout_r, _last_stdout_w) |
| 418 | + |
| 419 | + sb = stage.sandbox |
| 420 | + branch = sb._setup_branch() |
| 421 | + policy = sb._effective_policy() |
| 422 | + inner_policy = dataclasses.replace(policy, close_fds=False) |
| 423 | + for attr in ('_overlay_branch', '_cow_branch'): |
| 424 | + val = getattr(policy, attr, None) |
| 425 | + if val is not None: |
| 426 | + object.__setattr__(inner_policy, attr, val) |
| 427 | + |
| 428 | + ctx = SandboxContext(target, inner_policy, sb._id) |
| 429 | + ctx.__enter__() |
| 430 | + contexts[i] = ctx |
| 431 | + |
| 432 | + # Parent: close ALL pipe fds. Data flows kernel-only. |
| 433 | + for r, w in pipes: |
| 434 | + os.close(r) |
| 435 | + os.close(w) |
| 436 | + pipes = [] # prevent double-close in error path |
| 437 | + |
| 438 | + os.close(last_stderr_w) |
| 439 | + last_stderr_w = -1 |
| 440 | + |
| 441 | + if last_stdout_w >= 0: |
| 442 | + os.close(last_stdout_w) |
| 443 | + last_stdout_w = -1 |
| 444 | + |
| 445 | + for fd in opened_fds: |
| 446 | + try: |
| 447 | + os.close(fd) |
| 448 | + except OSError: |
| 449 | + pass |
| 450 | + opened_fds = [] |
| 451 | + |
| 452 | + # Wait for all stages. Timeout applies to the whole pipeline. |
| 453 | + exit_codes = [0] * n |
| 454 | + timed_out = False |
| 455 | + for i in range(n): |
| 456 | + ctx = contexts[i] |
| 457 | + if ctx is None: |
| 458 | + exit_codes[i] = -1 |
| 459 | + continue |
| 460 | + try: |
| 461 | + exit_codes[i] = ctx.wait(timeout=timeout) |
| 462 | + except TimeoutError: |
| 463 | + timed_out = True |
| 464 | + ctx.abort() |
| 465 | + exit_codes[i] = -1 |
| 466 | + # Abort remaining stages |
| 467 | + for j in range(i + 1, n): |
| 468 | + if contexts[j] is not None: |
| 469 | + contexts[j].abort() |
| 470 | + exit_codes[j] = -1 |
| 471 | + break |
| 472 | + |
| 473 | + # Clean up contexts |
| 474 | + for i in range(n): |
| 475 | + if contexts[i] is not None: |
| 476 | + try: |
| 477 | + contexts[i].__exit__(None, None, None) |
| 478 | + except Exception: |
| 479 | + pass |
| 480 | + contexts[i] = None |
| 481 | + |
| 482 | + # Finish branches |
| 483 | + for i, stage in enumerate(stages): |
| 484 | + stage.sandbox._finish_branch(error=exit_codes[i] != 0) |
| 485 | + |
| 486 | + # Read last stage's captured output |
| 487 | + if capture_stdout and last_stdout_r >= 0: |
| 488 | + stdout_data = _read_all_fd(last_stdout_r) |
| 489 | + os.close(last_stdout_r) |
| 490 | + last_stdout_r = -1 |
| 491 | + else: |
| 492 | + stdout_data = b"" |
| 493 | + if last_stdout_r >= 0: |
| 494 | + _drain_and_close(last_stdout_r) |
| 495 | + last_stdout_r = -1 |
| 496 | + |
| 497 | + stderr_data = _read_all_fd(last_stderr_r) |
| 498 | + os.close(last_stderr_r) |
| 499 | + last_stderr_r = -1 |
| 500 | + |
| 501 | + last_exit = exit_codes[-1] |
| 502 | + if timed_out: |
| 503 | + return Result( |
| 504 | + success=False, exit_code=-1, |
| 505 | + error="Pipeline timed out", |
| 506 | + stderr=stderr_data, |
| 507 | + ) |
| 508 | + |
| 509 | + return Result( |
| 510 | + success=(last_exit == 0), |
| 511 | + exit_code=last_exit, |
| 512 | + stdout=stdout_data, |
| 513 | + stderr=stderr_data, |
| 514 | + ) |
| 515 | + |
| 516 | + except BaseException: |
| 517 | + # Cleanup on error: close all fds, abort all contexts |
| 518 | + for r, w in pipes: |
| 519 | + try: |
| 520 | + os.close(r) |
| 521 | + except OSError: |
| 522 | + pass |
| 523 | + try: |
| 524 | + os.close(w) |
| 525 | + except OSError: |
| 526 | + pass |
| 527 | + for fd in (last_stderr_r, last_stderr_w, last_stdout_r, last_stdout_w): |
| 528 | + if fd >= 0: |
| 529 | + try: |
| 530 | + os.close(fd) |
| 531 | + except OSError: |
| 532 | + pass |
| 533 | + for fd in opened_fds: |
| 534 | + try: |
| 535 | + os.close(fd) |
| 536 | + except OSError: |
| 537 | + pass |
| 538 | + for ctx in contexts: |
| 539 | + if ctx is not None: |
| 540 | + try: |
| 541 | + ctx.abort() |
| 542 | + ctx.__exit__(None, None, None) |
| 543 | + except Exception: |
| 544 | + pass |
| 545 | + for stage in stages: |
| 546 | + try: |
| 547 | + stage.sandbox._finish_branch(error=True) |
| 548 | + except Exception: |
| 549 | + pass |
| 550 | + raise |
0 commit comments