diff --git a/cli/cli.go b/cli/cli.go index 0554f6d..f65826d 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -3,6 +3,7 @@ package cli import ( "context" "fmt" + "io" "log/slog" "os" "os/signal" @@ -59,7 +60,7 @@ Examples: }, }, Handler: func(inv *serpent.Invocation) error { - return Run(inv.Context(), config, inv.Args) + return Run(inv.Context(), inv.Stdin, inv.Stdout, inv.Stderr, config, inv.Args) }, } } @@ -89,7 +90,7 @@ func setupLogging(logLevel string) *slog.Logger { } // Run executes the jail command with the given configuration and arguments -func Run(ctx context.Context, config Config, args []string) error { +func Run(ctx context.Context, stdin io.Reader, stdout, stderr io.Writer, config Config, args []string) error { ctx, cancel := context.WithCancel(ctx) defer cancel() logger := setupLogging(config.LogLevel) @@ -155,7 +156,14 @@ func Run(ctx context.Context, config Config, args []string) error { // Execute command in jail go func() { defer cancel() - err := jailInstance.Command(args).Run() + cmd := jailInstance.Command(args) + + // Inject the stdout and stderr from the serpent cli. + cmd.Stdout = stdout + cmd.Stderr = stderr + cmd.Stdin = stdin + + err := cmd.Run() if err != nil { logger.Error("Command execution failed", "error", err) } diff --git a/cli/cli_test.go b/cli/cli_test.go index df99d67..7a7eed2 100644 --- a/cli/cli_test.go +++ b/cli/cli_test.go @@ -42,61 +42,51 @@ func (m *MockPTY) Stderr() string { } func (m *MockPTY) Clear() { - m.stdout = strings.Builder{} m.stderr = strings.Builder{} + m.stdout = strings.Builder{} } -func (m *MockPTY) ExpectMatch(content string) { +func (m *MockPTY) ExpectInStdout(content string) { if !strings.Contains(m.stdout.String(), content) { m.t.Fatalf("expected \"%s\", got: %s", content, m.stdout.String()) } } -func (m *MockPTY) ExpectError(content string) { +func (m *MockPTY) ExpectInStderr(content string) { if !strings.Contains(m.stderr.String(), content) { - m.t.Fatalf("expected error with \"%s\", got: %s", content, m.stderr.String()) - } -} - -func (m *MockPTY) RequireError() { - if m.stderr.String() == "" { - m.t.Fatal("expected error") - } -} - -func (m *MockPTY) RequireNoError() { - if m.stderr.String() != "" { - m.t.Fatalf("expected nothing in stderr, but got: %s", m.stderr.String()) + m.t.Fatalf("expected \"%s\", got: %s", content, m.stderr.String()) } } func TestPtySetupWorks(t *testing.T) { - cmd := NewCommand() - inv := cmd.Invoke("--help") - - pty := NewMockPTY(t) - pty.Attach(inv) - - if err := inv.Run(); err != nil { - t.Fatalf("could not run with simple --help arg: %v", err) - } - - pty.RequireNoError() - pty.ExpectMatch("Monitor and restrict HTTP/HTTPS requests from processes") -} - -func TestCurlGithub(t *testing.T) { - ensureRoot(t) - - cmd := NewCommand() - inv := cmd.Invoke("--allow", "\"github.com\"", "--", "curl", "https://github.com") - - pty := NewMockPTY(t) - pty.Attach(inv) - - if err := inv.Run(); err != nil { - t.Fatalf("error curling github: %v", err) - } - pty.RequireNoError() + t.Run("help command", func(t *testing.T) { + inv := NewCommand().Invoke("--help") + pty := NewMockPTY(t) + pty.Attach(inv) + if err := inv.Run(); err != nil { + t.Fatalf("could not run with simple --help arg: %v", err) + } + pty.ExpectInStdout("Monitor and restrict HTTP/HTTPS requests from processes") + }) + + t.Run("just a url", func(t *testing.T) { + inv := NewCommand().Invoke("--allow", "\"pastebin.com\"", "--", "curl", "https://pastebin.com/raw/2q6kyAyQ") + pty := NewMockPTY(t) + pty.Attach(inv) + if err := inv.Run(); err != nil { + t.Fatalf("error curling pastebin test fixture: %v", err) + } + pty.ExpectInStdout("foo") + }) + + t.Run("allow all with asterisk", func(t *testing.T) { + inv := NewCommand().Invoke("--allow", "\"*\"", "--", "curl", "https://pastebin.com/raw/2q6kyAyQ") + pty := NewMockPTY(t) + pty.Attach(inv) + if err := inv.Run(); err != nil { + t.Fatalf("error curling pastebin test fixture: %v", err) + } + pty.ExpectInStdout("foo") + }) } diff --git a/namespace/linux.go b/namespace/linux.go index accc168..502ba9c 100644 --- a/namespace/linux.go +++ b/namespace/linux.go @@ -145,9 +145,6 @@ func (l *Linux) Command(command []string) *exec.Cmd { env = append(env, fmt.Sprintf("%s=%s", key, value)) } cmd.Env = env - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr // Use prepared process attributes from Open method cmd.SysProcAttr = l.procAttr diff --git a/namespace/macos.go b/namespace/macos.go index a9db968..f54929c 100644 --- a/namespace/macos.go +++ b/namespace/macos.go @@ -146,9 +146,6 @@ func (m *MacOSNetJail) Command(command []string) *exec.Cmd { env = append(env, fmt.Sprintf("%s=%s", key, value)) } cmd.Env = env - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Stdin = os.Stdin // Use prepared process attributes from Open method cmd.SysProcAttr = m.procAttr diff --git a/rules/rules_test.go b/rules/rules_test.go index 5fbe009..829f7c3 100644 --- a/rules/rules_test.go +++ b/rules/rules_test.go @@ -160,6 +160,7 @@ func TestWildcardMatch(t *testing.T) { // Wildcard * tests {"star matches all", "*", "anything.com", true}, + {"star matches all with path", "*", "anything.com/whatever", true}, {"star matches empty", "*", "", true}, {"prefix star", "github.*", "github.com", true}, {"prefix star long", "github.*", "github.com/user/repo", true},