Skip to content

Commit bf4f144

Browse files
Merge pull request #96 from 73ai/sandboxing-bash
feat(tools): three-tier tool safety model
2 parents 03f4e2d + 0340733 commit bf4f144

27 files changed

+1898
-94
lines changed

agent/tools/approval_rules.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,20 @@ func extractPattern(toolName string, input json.RawMessage) string {
124124
return ""
125125
}
126126
switch toolName {
127+
case "bash":
128+
if cmd, ok := m["command"]; ok {
129+
var s string
130+
if json.Unmarshal(cmd, &s) == nil {
131+
return firstToken(s)
132+
}
133+
}
134+
case "file_write", "file_edit":
135+
if p, ok := m["path"]; ok {
136+
var s string
137+
if json.Unmarshal(p, &s) == nil {
138+
return s
139+
}
140+
}
127141
case "slack_send", "slack_read_channel":
128142
if ch, ok := m["channel"]; ok {
129143
var s string

agent/tools/approval_rules_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,59 @@ func TestExtractPattern_GWSMissingCommand(t *testing.T) {
146146
t.Errorf("pattern = %q, want gws_execute (fallback)", p)
147147
}
148148
}
149+
150+
func TestExtractPattern_BashCommand(t *testing.T) {
151+
input, _ := json.Marshal(map[string]string{"command": "curl example.com"})
152+
if p := extractPattern("bash", input); p != "curl" {
153+
t.Errorf("pattern = %q, want curl", p)
154+
}
155+
}
156+
157+
func TestExtractPattern_BashSingleWord(t *testing.T) {
158+
input, _ := json.Marshal(map[string]string{"command": "ls"})
159+
if p := extractPattern("bash", input); p != "ls" {
160+
t.Errorf("pattern = %q, want ls", p)
161+
}
162+
}
163+
164+
func TestExtractPattern_FileWrite(t *testing.T) {
165+
input, _ := json.Marshal(map[string]string{"path": "/tmp/test.txt", "content": "hello"})
166+
if p := extractPattern("file_write", input); p != "/tmp/test.txt" {
167+
t.Errorf("pattern = %q, want /tmp/test.txt", p)
168+
}
169+
}
170+
171+
func TestExtractPattern_FileEdit(t *testing.T) {
172+
input, _ := json.Marshal(map[string]string{"path": "/tmp/test.txt", "old_string": "a", "new_string": "b"})
173+
if p := extractPattern("file_edit", input); p != "/tmp/test.txt" {
174+
t.Errorf("pattern = %q, want /tmp/test.txt", p)
175+
}
176+
}
177+
178+
func TestApprovalRuleSet_WildcardPattern(t *testing.T) {
179+
s := NewApprovalRuleSet()
180+
s.Add(ApprovalRule{ToolName: "bash", Pattern: ""})
181+
input, _ := json.Marshal(map[string]string{"command": "anything"})
182+
if !s.Matches("bash", input) {
183+
t.Error("empty pattern should match any input")
184+
}
185+
}
186+
187+
func TestApprovalRuleSet_DuplicateRulePrevention(t *testing.T) {
188+
s := NewApprovalRuleSet()
189+
input, _ := json.Marshal(map[string]string{"channel": "#general"})
190+
for i := 0; i < autoApproveThreshold*3; i++ {
191+
s.RecordApproval("slack_send", input)
192+
}
193+
s.mu.Lock()
194+
count := 0
195+
for _, r := range s.rules {
196+
if r.ToolName == "slack_send" && r.Pattern == "#general" {
197+
count++
198+
}
199+
}
200+
s.mu.Unlock()
201+
if count != 1 {
202+
t.Errorf("expected 1 rule, got %d (duplicate prevention failed)", count)
203+
}
204+
}

agent/tools/bash.go

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ const defaultBashTimeout = 30 * time.Second
1414

1515
// BashTool executes shell commands.
1616
type BashTool struct {
17-
timeout time.Duration
18-
filter *CommandFilter
19-
workDir string
17+
timeout time.Duration
18+
filter *CommandFilter
19+
workDir string
20+
interactor Interactor
21+
approvalRules *ApprovalRuleSet
2022
}
2123

2224
// BashOption configures a BashTool.
@@ -32,6 +34,16 @@ func WithWorkDir(dir string) BashOption {
3234
return func(b *BashTool) { b.workDir = dir }
3335
}
3436

37+
// WithInteractor sets the interactor for approval prompts.
38+
func WithInteractor(i Interactor) BashOption {
39+
return func(b *BashTool) { b.interactor = i }
40+
}
41+
42+
// WithApprovalRuleSet sets the approval rules for session-scoped auto-approve.
43+
func WithApprovalRuleSet(rules *ApprovalRuleSet) BashOption {
44+
return func(b *BashTool) { b.approvalRules = rules }
45+
}
46+
3547
// NewBashTool creates a new bash tool with the given timeout and options.
3648
func NewBashTool(timeout time.Duration, opts ...BashOption) *BashTool {
3749
if timeout == 0 {
@@ -68,6 +80,7 @@ func (b *BashTool) Execute(ctx context.Context, input json.RawMessage) (string,
6880
if err := json.Unmarshal(input, &in); err != nil {
6981
return "", fmt.Errorf("parse input: %w", err)
7082
}
83+
in.Command = strings.TrimSpace(in.Command)
7184
if in.Command == "" {
7285
return "", fmt.Errorf("command is required")
7386
}
@@ -76,14 +89,32 @@ func (b *BashTool) Execute(ctx context.Context, input json.RawMessage) (string,
7689
return "", fmt.Errorf("gws commands must use the gws_execute tool, not bash")
7790
}
7891

79-
if err := b.filter.Check(in.Command); err != nil {
80-
return "", fmt.Errorf("command blocked: %w", err)
92+
filterResult, filterErr := b.filter.CheckWithResult(in.Command)
93+
switch filterResult {
94+
case FilterDeny:
95+
if filterErr != nil {
96+
return "", fmt.Errorf("command blocked: %w", filterErr)
97+
}
98+
return "", fmt.Errorf("command blocked")
99+
case FilterPrompt:
100+
if b.interactor == nil {
101+
return "", fmt.Errorf("command blocked: no interactor for approval")
102+
}
103+
return GuardedAction(ctx, b.interactor, RiskMedium,
104+
"Run: "+in.Command,
105+
func() (string, error) { return b.runCommand(ctx, in.Command) },
106+
WithApprovalRules(b.approvalRules, "bash", input),
107+
)
81108
}
82109

110+
return b.runCommand(ctx, in.Command)
111+
}
112+
113+
func (b *BashTool) runCommand(ctx context.Context, command string) (string, error) {
83114
ctx, cancel := context.WithTimeout(ctx, b.timeout)
84115
defer cancel()
85116

86-
cmd := exec.CommandContext(ctx, "bash", "-c", in.Command)
117+
cmd := exec.CommandContext(ctx, "bash", "-c", command)
87118
if b.workDir != "" {
88119
cmd.Dir = b.workDir
89120
}

agent/tools/bash_filter.go

Lines changed: 77 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,20 @@ import (
77
"strings"
88
)
99

10+
// FilterResult indicates the outcome of a command filter check.
type FilterResult int

const (
	// FilterAllow means the command is on the allowlist and may run freely.
	FilterAllow FilterResult = iota
	// FilterDeny means the command is hard blocked.
	FilterDeny
	// FilterPrompt means the command is not on the allowlist and the user
	// must be asked.
	FilterPrompt
)
18+
1019
// CommandFilter validates shell commands against an allowlist or blocklist.
type CommandFilter struct {
	// allowed, if set, lists the only command names that pass. Entries are
	// compared for equality against the basename of a command token.
	allowed []string
	// blocked, if set, lists command names that are rejected, compared the
	// same way as allowed.
	blocked []string
	// softAllow, if true, makes a non-matching allowlist check return
	// FilterPrompt instead of FilterDeny.
	softAllow bool
}
1525

1626
// NewAllowlistFilter creates a filter that only permits commands
@@ -19,34 +29,61 @@ func NewAllowlistFilter(prefixes []string) *CommandFilter {
1929
return &CommandFilter{allowed: prefixes}
2030
}
2131

32+
// NewSoftAllowlistFilter creates a filter that auto-allows commands on the
33+
// allowlist and returns FilterPrompt (not FilterDeny) for everything else.
34+
// Use this for interactive mode where unknown commands should be approved by the user.
35+
func NewSoftAllowlistFilter(prefixes []string) *CommandFilter {
36+
return &CommandFilter{allowed: prefixes, softAllow: true}
37+
}
38+
2239
// NewBlocklistFilter creates a filter that rejects commands
2340
// whose first token matches any of the given prefixes.
2441
func NewBlocklistFilter(prefixes []string) *CommandFilter {
2542
return &CommandFilter{blocked: prefixes}
2643
}
2744

28-
// Check validates the given command string. It splits on shell
29-
// operators (|, &&, ;, ||) and checks each segment. It also
30-
// detects command substitution via $() and backticks.
31-
func (f *CommandFilter) Check(command string) error {
45+
// CheckWithResult validates the given command string and returns a FilterResult
46+
// indicating whether to allow, deny, or prompt the user.
47+
func (f *CommandFilter) CheckWithResult(command string) (FilterResult, error) {
3248
if f == nil {
33-
return nil
49+
return FilterAllow, nil
3450
}
3551

3652
segments := splitShellSegments(command)
3753
for _, seg := range segments {
38-
if err := f.checkSegment(seg); err != nil {
39-
return err
54+
result, err := f.checkSegmentResult(seg)
55+
if err != nil || result != FilterAllow {
56+
return result, err
4057
}
4158
}
4259

43-
// Check inside $() and backtick substitutions.
4460
for _, sub := range extractSubstitutions(command) {
45-
if err := f.Check(sub); err != nil {
46-
return fmt.Errorf("in command substitution: %w", err)
61+
result, err := f.CheckWithResult(sub)
62+
if err != nil {
63+
return result, fmt.Errorf("in command substitution: %w", err)
64+
}
65+
if result != FilterAllow {
66+
return result, nil
4767
}
4868
}
4969

70+
return FilterAllow, nil
71+
}
72+
73+
// Check validates the given command string. It splits on shell
74+
// operators (|, &&, ;, ||) and checks each segment. It also
75+
// detects command substitution via $() and backticks.
76+
func (f *CommandFilter) Check(command string) error {
77+
result, err := f.CheckWithResult(command)
78+
if err != nil {
79+
return err
80+
}
81+
if result == FilterDeny {
82+
return fmt.Errorf("command not permitted")
83+
}
84+
if result == FilterPrompt {
85+
return fmt.Errorf("command requires approval")
86+
}
5087
return nil
5188
}
5289

@@ -55,41 +92,43 @@ func basename(token string) string {
5592
return filepath.Base(token)
5693
}
5794

58-
// checkSegment validates a single command segment.
59-
// Allowlist: only the first token must match.
60-
// Blocklist: every token is checked to catch wrappers like "env curl".
61-
func (f *CommandFilter) checkSegment(seg string) error {
95+
// checkSegmentResult validates a single command segment and returns a FilterResult.
96+
func (f *CommandFilter) checkSegmentResult(seg string) (FilterResult, error) {
6297
fields := strings.Fields(strings.TrimSpace(seg))
6398
if len(fields) == 0 {
64-
return nil
99+
return FilterAllow, nil
65100
}
66101
if len(f.allowed) > 0 {
67-
return f.checkToken(fields[0])
102+
return f.checkTokenResult(fields[0])
68103
}
69104
for _, tok := range fields {
70-
if err := f.checkToken(tok); err != nil {
71-
return err
105+
result, err := f.checkTokenResult(tok)
106+
if err != nil || result != FilterAllow {
107+
return result, err
72108
}
73109
}
74-
return nil
110+
return FilterAllow, nil
75111
}
76112

77-
func (f *CommandFilter) checkToken(token string) error {
113+
func (f *CommandFilter) checkTokenResult(token string) (FilterResult, error) {
78114
base := basename(token)
79115
if len(f.allowed) > 0 {
80116
for _, prefix := range f.allowed {
81117
if base == prefix {
82-
return nil
118+
return FilterAllow, nil
83119
}
84120
}
85-
return fmt.Errorf("command %q not in allowlist", token)
121+
if f.softAllow {
122+
return FilterPrompt, nil
123+
}
124+
return FilterDeny, fmt.Errorf("command %q not in allowlist", token)
86125
}
87126
for _, prefix := range f.blocked {
88127
if base == prefix {
89-
return fmt.Errorf("command %q is blocked", token)
128+
return FilterDeny, fmt.Errorf("command %q is blocked", token)
90129
}
91130
}
92-
return nil
131+
return FilterAllow, nil
93132
}
94133

95134
// splitShellSegments splits a command on |, &&, ;, and || operators.
@@ -139,7 +178,18 @@ func extractSubstitutions(cmd string) []string {
139178
return subs
140179
}
141180

142-
// DefaultBlocklist is the default set of blocked commands for interactive mode.
181+
// InteractiveAllowlist is the set of commands auto-allowed in interactive mode.
// Commands not on this list require user approval (FilterPrompt).
//
// NOTE(review): in allowlist mode only the first token of a segment is
// checked, so arguments are not inspected. Some entries here can themselves
// run other programs (for example find with -exec); confirm that is
// acceptable for auto-allow.
var InteractiveAllowlist = []string{
	"obk", "sqlite3",
	"ls", "cat", "head", "tail", "wc", "sort", "uniq", "diff",
	"find", "grep", "rg",
	"date", "cal", "echo", "printf",
	"git", "tree", "file", "stat", "jq", "which",
}
190+
191+
// DefaultBlocklist is the legacy blocklist used when no Interactor is provided
192+
// (e.g. subagents). Interactive mode now uses InteractiveAllowlist instead.
143193
var DefaultBlocklist = []string{
144194
// Network
145195
"curl", "wget", "nc", "ncat", "nmap",

0 commit comments

Comments
 (0)