Skip to content

Commit ab0b094

Browse files
committed
Add Stage/Pipeline for piping between sandboxes
Sandbox.cmd() returns a lazy Stage. Stages chain with | into a Pipeline. Each stage runs in its own Landlock+seccomp sandbox; inter-stage data flows through kernel pipe buffers the parent never holds. This enables XOA (Execute-Only Agents): a planner with LLM access pipes a script to an executor with data access. Disjoint policies ensure untrusted data never reaches the LLM. result = ( Sandbox(planner_policy).cmd(["python3", "plan.py"]) | Sandbox(executor_policy).cmd(["python3", "-"]) ).run() Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent bed191e commit ab0b094

File tree

4 files changed

+582
-3
lines changed

4 files changed

+582
-3
lines changed

src/sandlock/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ._version import __version__
99
from .policy import Policy, FsIsolation, BranchAction, parse_ports
1010
from .sandbox import Sandbox
11-
from ._runner import Result
11+
from ._runner import Result, Stage, Pipeline
1212
from ._checkpoint import Checkpoint
1313
from ._notif_policy import NotifPolicy, NotifAction, PathRule
1414
from ._profile import load_profile, list_profiles
@@ -32,6 +32,8 @@
3232
"__version__",
3333
# Core API
3434
"Sandbox",
35+
"Stage",
36+
"Pipeline",
3537
"Policy",
3638
"FsIsolation",
3739
"BranchAction",

src/sandlock/_runner.py

Lines changed: 332 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
import os
88
import signal
99
from dataclasses import dataclass, field
10-
from typing import Any, Callable
10+
from typing import Any, Callable, TYPE_CHECKING
1111

1212
from .exceptions import SandboxError
1313
from ._context import SandboxContext
1414
from .policy import Policy
1515

16+
if TYPE_CHECKING:
17+
from .sandbox import Sandbox
18+
1619

1720
@dataclass
1821
class Result:
@@ -217,3 +220,331 @@ def _drain_and_close(fd: int) -> None:
217220
os.close(fd)
218221
except OSError:
219222
pass
223+
224+
225+
# --- Stage / Pipeline ---
226+
227+
228+
class Stage:
229+
"""A sandbox bound to a command, not yet running.
230+
231+
Created by :meth:`Sandbox.cmd`. Stages can be chained into a
232+
pipeline with the ``|`` operator::
233+
234+
result = (
235+
Sandbox(policy_a).cmd(["producer"])
236+
| Sandbox(policy_b).cmd(["consumer"])
237+
).run()
238+
"""
239+
240+
__slots__ = ("sandbox", "args")
241+
242+
def __init__(self, sandbox: Sandbox, args: list[str]):
243+
self.sandbox = sandbox
244+
self.args = args
245+
246+
def __or__(self, other: Stage | Pipeline) -> Pipeline:
247+
if isinstance(other, Pipeline):
248+
return Pipeline([self] + other.stages)
249+
if isinstance(other, Stage):
250+
return Pipeline([self, other])
251+
return NotImplemented
252+
253+
def run(self, *, timeout: float | None = None) -> Result:
254+
"""Run this single stage. Equivalent to ``sandbox.run(args)``."""
255+
return self.sandbox.run(self.args, timeout=timeout)
256+
257+
258+
class Pipeline:
259+
"""A chain of stages connected by pipes.
260+
261+
Each stage's stdout is wired to the next stage's stdin. The
262+
parent process never holds the inter-stage pipe data — it flows
263+
through kernel buffers only.
264+
265+
Created by ``stage_a | stage_b``.
266+
"""
267+
268+
__slots__ = ("stages",)
269+
270+
def __init__(self, stages: list[Stage]):
271+
if len(stages) < 2:
272+
raise ValueError("Pipeline requires at least 2 stages")
273+
self.stages = stages
274+
275+
def __or__(self, other: Stage | Pipeline) -> Pipeline:
276+
if isinstance(other, Stage):
277+
return Pipeline(self.stages + [other])
278+
if isinstance(other, Pipeline):
279+
return Pipeline(self.stages + other.stages)
280+
return NotImplemented
281+
282+
def run(
283+
self,
284+
*,
285+
stdout: int | None = None,
286+
timeout: float | None = None,
287+
) -> Result:
288+
"""Execute the pipeline.
289+
290+
Each stage runs in its own sandbox. Stdout of stage N is piped
291+
to stdin of stage N+1. The parent never reads inter-stage data.
292+
293+
Args:
294+
stdout: File descriptor for the final stage's stdout.
295+
When set, the last stage writes directly to this fd
296+
and ``result.stdout`` is empty. Use ``sys.stdout.fileno()``
297+
to send output to the terminal.
298+
timeout: Maximum seconds to wait for the entire pipeline.
299+
300+
Returns:
301+
Result from the last stage. ``stdout`` contains the last
302+
stage's output only when the *stdout* parameter is not set.
303+
"""
304+
return run_pipeline(self.stages, stdout=stdout, timeout=timeout)
305+
306+
307+
def run_pipeline(
308+
stages: list[Stage],
309+
*,
310+
stdout: int | None = None,
311+
timeout: float | None = None,
312+
) -> Result:
313+
"""Execute a pipeline of sandboxed stages connected by pipes.
314+
315+
The parent creates pipes between adjacent stages and closes its
316+
copies immediately after fork. Data flows through kernel buffers;
317+
the parent never reads inter-stage data.
318+
319+
Args:
320+
stages: Ordered list of stages to execute.
321+
stdout: File descriptor for the final stage's stdout.
322+
timeout: Maximum seconds to wait for the entire pipeline.
323+
324+
Returns:
325+
Result from the last stage.
326+
"""
327+
n = len(stages)
328+
if n == 0:
329+
raise ValueError("Pipeline requires at least 1 stage")
330+
if n == 1:
331+
return stages[0].run(timeout=timeout)
332+
333+
# Create inter-stage pipes: pipe[i] connects stage i → stage i+1
334+
# Each pipe is (read_fd, write_fd)
335+
pipes = [os.pipe() for _ in range(n - 1)]
336+
337+
# Create stderr pipe for last stage
338+
last_stderr_r, last_stderr_w = os.pipe()
339+
340+
# Capture stdout of last stage (unless caller provided an fd)
341+
capture_stdout = stdout is None
342+
if capture_stdout:
343+
last_stdout_r, last_stdout_w = os.pipe()
344+
else:
345+
last_stdout_r, last_stdout_w = -1, -1
346+
347+
contexts = [None] * n # type: list
348+
opened_fds = [] # type: list # track fds we need to clean up on error
349+
350+
try:
351+
for i, stage in enumerate(stages):
352+
# Determine this stage's stdin/stdout fds
353+
stdin_fd = pipes[i - 1][0] if i > 0 else -1
354+
if i < n - 1:
355+
stdout_fd = pipes[i][1]
356+
elif capture_stdout:
357+
stdout_fd = last_stdout_w
358+
else:
359+
stdout_fd = os.dup(stdout)
360+
opened_fds.append(stdout_fd)
361+
stderr_fd = last_stderr_w if i == n - 1 else -1
362+
363+
# Build the target function for this stage. Capture fds
364+
# and pipe list by value to avoid closure issues.
365+
cmd = stage.args
366+
_stdin = stdin_fd
367+
_stdout = stdout_fd
368+
_stderr = stderr_fd
369+
_pipes = pipes
370+
_last_stderr_w = last_stderr_w
371+
_last_stdout_r = last_stdout_r
372+
_last_stdout_w = last_stdout_w
373+
374+
def _make_target(cmd, _stdin, _stdout, _stderr,
375+
_pipes, _last_stderr_w,
376+
_last_stdout_r, _last_stdout_w):
377+
def _target():
378+
# Close all pipe fds the child doesn't need
379+
for r, w in _pipes:
380+
if r != _stdin:
381+
os.close(r)
382+
if w != _stdout:
383+
os.close(w)
384+
if _last_stderr_w >= 0 and _last_stderr_w != _stderr:
385+
os.close(_last_stderr_w)
386+
if _last_stdout_r >= 0:
387+
os.close(_last_stdout_r)
388+
if _last_stdout_w >= 0 and _last_stdout_w != _stdout:
389+
os.close(_last_stdout_w)
390+
391+
# Wire stdin
392+
if _stdin >= 0:
393+
os.dup2(_stdin, 0)
394+
os.close(_stdin)
395+
396+
# Wire stdout
397+
os.dup2(_stdout, 1)
398+
if _stdout > 2:
399+
os.close(_stdout)
400+
401+
# Wire stderr (last stage only)
402+
if _stderr >= 0:
403+
os.dup2(_stderr, 2)
404+
if _stderr > 2:
405+
os.close(_stderr)
406+
407+
try:
408+
os.execvp(cmd[0], cmd)
409+
except OSError as e:
410+
os.write(2, f"exec failed: {e}\n".encode())
411+
os._exit(127)
412+
413+
return _target
414+
415+
target = _make_target(cmd, _stdin, _stdout, _stderr,
416+
_pipes, _last_stderr_w,
417+
_last_stdout_r, _last_stdout_w)
418+
419+
sb = stage.sandbox
420+
branch = sb._setup_branch()
421+
policy = sb._effective_policy()
422+
inner_policy = dataclasses.replace(policy, close_fds=False)
423+
for attr in ('_overlay_branch', '_cow_branch'):
424+
val = getattr(policy, attr, None)
425+
if val is not None:
426+
object.__setattr__(inner_policy, attr, val)
427+
428+
ctx = SandboxContext(target, inner_policy, sb._id)
429+
ctx.__enter__()
430+
contexts[i] = ctx
431+
432+
# Parent: close ALL pipe fds. Data flows kernel-only.
433+
for r, w in pipes:
434+
os.close(r)
435+
os.close(w)
436+
pipes = [] # prevent double-close in error path
437+
438+
os.close(last_stderr_w)
439+
last_stderr_w = -1
440+
441+
if last_stdout_w >= 0:
442+
os.close(last_stdout_w)
443+
last_stdout_w = -1
444+
445+
for fd in opened_fds:
446+
try:
447+
os.close(fd)
448+
except OSError:
449+
pass
450+
opened_fds = []
451+
452+
# Wait for all stages. Timeout applies to the whole pipeline.
453+
exit_codes = [0] * n
454+
timed_out = False
455+
for i in range(n):
456+
ctx = contexts[i]
457+
if ctx is None:
458+
exit_codes[i] = -1
459+
continue
460+
try:
461+
exit_codes[i] = ctx.wait(timeout=timeout)
462+
except TimeoutError:
463+
timed_out = True
464+
ctx.abort()
465+
exit_codes[i] = -1
466+
# Abort remaining stages
467+
for j in range(i + 1, n):
468+
if contexts[j] is not None:
469+
contexts[j].abort()
470+
exit_codes[j] = -1
471+
break
472+
473+
# Clean up contexts
474+
for i in range(n):
475+
if contexts[i] is not None:
476+
try:
477+
contexts[i].__exit__(None, None, None)
478+
except Exception:
479+
pass
480+
contexts[i] = None
481+
482+
# Finish branches
483+
for i, stage in enumerate(stages):
484+
stage.sandbox._finish_branch(error=exit_codes[i] != 0)
485+
486+
# Read last stage's captured output
487+
if capture_stdout and last_stdout_r >= 0:
488+
stdout_data = _read_all_fd(last_stdout_r)
489+
os.close(last_stdout_r)
490+
last_stdout_r = -1
491+
else:
492+
stdout_data = b""
493+
if last_stdout_r >= 0:
494+
_drain_and_close(last_stdout_r)
495+
last_stdout_r = -1
496+
497+
stderr_data = _read_all_fd(last_stderr_r)
498+
os.close(last_stderr_r)
499+
last_stderr_r = -1
500+
501+
last_exit = exit_codes[-1]
502+
if timed_out:
503+
return Result(
504+
success=False, exit_code=-1,
505+
error="Pipeline timed out",
506+
stderr=stderr_data,
507+
)
508+
509+
return Result(
510+
success=(last_exit == 0),
511+
exit_code=last_exit,
512+
stdout=stdout_data,
513+
stderr=stderr_data,
514+
)
515+
516+
except BaseException:
517+
# Cleanup on error: close all fds, abort all contexts
518+
for r, w in pipes:
519+
try:
520+
os.close(r)
521+
except OSError:
522+
pass
523+
try:
524+
os.close(w)
525+
except OSError:
526+
pass
527+
for fd in (last_stderr_r, last_stderr_w, last_stdout_r, last_stdout_w):
528+
if fd >= 0:
529+
try:
530+
os.close(fd)
531+
except OSError:
532+
pass
533+
for fd in opened_fds:
534+
try:
535+
os.close(fd)
536+
except OSError:
537+
pass
538+
for ctx in contexts:
539+
if ctx is not None:
540+
try:
541+
ctx.abort()
542+
ctx.__exit__(None, None, None)
543+
except Exception:
544+
pass
545+
for stage in stages:
546+
try:
547+
stage.sandbox._finish_branch(error=True)
548+
except Exception:
549+
pass
550+
raise

0 commit comments

Comments
 (0)