Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cli
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"os/signal"
Expand Down Expand Up @@ -61,7 +62,7 @@ Examples:
},
},
Handler: func(inv *serpent.Invocation) error {
return Run(config, inv.Args)
return Run(config, inv.Stdin, inv.Stdout, inv.Stderr, inv.Args)
},
}
}
Expand Down Expand Up @@ -91,7 +92,7 @@ func setupLogging(logLevel string) *slog.Logger {
}

// Run executes the jail command with the given configuration and arguments
func Run(config Config, args []string) error {
func Run(config Config, stdin io.Reader, stdout, stderr io.Writer, args []string) error {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose these should go in the config

logger := setupLogging(config.LogLevel)

// Get command arguments
Expand Down Expand Up @@ -211,7 +212,14 @@ func Run(config Config, args []string) error {
// Execute command in jail
go func() {
defer cancel()
err := jailInstance.Command(args).Run()
cmd := jailInstance.Command(args)

// Inject the stdout and stderr from the serpent cli.
cmd.Stdout = stdout
cmd.Stderr = stderr
cmd.Stdin = stdin

err := cmd.Run()
if err != nil {
logger.Error("Command execution failed", "error", err)
}
Expand Down
40 changes: 18 additions & 22 deletions cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,27 +46,15 @@ func (m *MockPTY) Clear() {
m.stderr = strings.Builder{}
}

func (m *MockPTY) ExpectMatch(content string) {
func (m *MockPTY) ExpectInStdout(content string) {
if !strings.Contains(m.stdout.String(), content) {
m.t.Fatalf("expected \"%s\", got: %s", content, m.stdout.String())
}
}

func (m *MockPTY) ExpectError(content string) {
func (m *MockPTY) ExpectInStderr(content string) {
if !strings.Contains(m.stderr.String(), content) {
m.t.Fatalf("expected error with \"%s\", got: %s", content, m.stderr.String())
}
}

func (m *MockPTY) RequireError() {
if m.stderr.String() == "" {
m.t.Fatal("expected error")
}
}

func (m *MockPTY) RequireNoError() {
if m.stderr.String() != "" {
m.t.Fatalf("expected nothing in stderr, but got: %s", m.stderr.String())
m.t.Fatalf("expected \"%s\", got: %s", content, m.stderr.String())
}
}

Expand All @@ -81,22 +69,30 @@ func TestPtySetupWorks(t *testing.T) {
t.Fatalf("could not run with simple --help arg: %v", err)
}

pty.RequireNoError()
pty.ExpectMatch("Monitor and restrict HTTP/HTTPS requests from processes")
pty.ExpectInStdout("Monitor and restrict HTTP/HTTPS requests from processes")
}

func TestCurlGithub(t *testing.T) {
// For these tests, I have a fixture in the form of a pastebin: https://pastebin.com/raw/2q6kyAyQ
func TestCurlPastebin(t *testing.T) {
ensureRoot(t)

cmd := NewCommand()
inv := cmd.Invoke("--allow", "\"github.com\"", "--", "curl", "https://github.com")
inv := cmd.Invoke("--allow", "\"pastebin.com\"", "--", "curl", "https://pastebin.com/raw/2q6kyAyQ")

pty := NewMockPTY(t)
pty.Attach(inv)

if err := inv.Run(); err != nil {
t.Fatalf("error curling github: %v", err)
t.Fatalf("error curling pastebin test fixture: %v", err)
}
pty.ExpectInStdout("foo")
pty.Clear()

pty.RequireNoError()
// Allowing all with a glob should allow the request
inv = cmd.Invoke("--allow", "*", "--", "curl", "https://pastebin.com/raw/2q6kyAyQ")
pty.Attach(inv)
if err := inv.Run(); err != nil {
t.Fatalf("error curling pastebin test fixture: %v", err)
}
pty.ExpectInStdout("foo")
pty.Clear()
}
5 changes: 1 addition & 4 deletions namespace/linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,6 @@ func (l *Linux) Command(command []string) *exec.Cmd {
env = append(env, fmt.Sprintf("%s=%s", key, value))
}
cmd.Env = env
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr

// Use prepared process attributes from Open method
cmd.SysProcAttr = l.procAttr
Expand Down Expand Up @@ -300,4 +297,4 @@ func (l *Linux) removeNamespace() error {
return fmt.Errorf("failed to remove namespace: %v", err)
}
return nil
}
}
5 changes: 1 addition & 4 deletions namespace/macos.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,6 @@ func (m *MacOSNetJail) Command(command []string) *exec.Cmd {
env = append(env, fmt.Sprintf("%s=%s", key, value))
}
cmd.Env = env
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Stdin = os.Stdin

// Use prepared process attributes from Open method
cmd.SysProcAttr = m.procAttr
Expand Down Expand Up @@ -367,4 +364,4 @@ func (m *MacOSNetJail) cleanupTempFiles() {
if m.mainRulesPath != "" {
os.Remove(m.mainRulesPath)
}
}
}
2 changes: 1 addition & 1 deletion namespace/namespace.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ func New(config Config, logger *slog.Logger) (jail.Commander, error) {

func newNamespaceName() string {
return fmt.Sprintf("%s_%d", namespacePrefix, time.Now().UnixNano()%10000000)
}
}
Loading