Skip to content

Commit b656cf2

Browse files
Merge pull request #108 from 73ai/background-agent-refactor
Background agent improvements: retry, env scrub, tasks, budget, reactive triggers
2 parents 6069a34 + ee84c06 commit b656cf2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+2313
-190
lines changed

agent/agent.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ type Agent struct {
4040
summarizer Summarizer
4141
rateLimiter *provider.RateLimiter
4242
usageRecorder UsageRecorder
43+
budgetChecker BudgetChecker
4344
}
4445

4546
// Option configures an Agent.
@@ -95,6 +96,11 @@ func WithSummarizer(s Summarizer) Option {
9596
return func(a *Agent) { a.summarizer = s }
9697
}
9798

99+
// WithBudgetChecker sets a budget checker that is called after each LLM call.
100+
func WithBudgetChecker(bc BudgetChecker) Option {
101+
return func(a *Agent) { a.budgetChecker = bc }
102+
}
103+
98104
// New creates a new Agent.
99105
func New(p provider.Provider, model string, executor ToolExecutor, opts ...Option) *Agent {
100106
a := &Agent{
@@ -143,6 +149,16 @@ func (a *Agent) Run(ctx context.Context, input string) (string, error) {
143149
if a.usageRecorder != nil {
144150
a.usageRecorder.RecordUsage(a.model, resp.Usage)
145151
}
152+
if a.budgetChecker != nil {
153+
if err := a.budgetChecker.CheckBudget(); err != nil {
154+
// Return partial text on budget exceeded (graceful).
155+
text := resp.TextContent()
156+
if text == "" {
157+
text = "(budget exceeded before completion)"
158+
}
159+
return text, err
160+
}
161+
}
146162

147163
// Append assistant response to history.
148164
a.history = append(a.history, provider.Message{

agent/agent_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,55 @@ func TestLoop_TokenBasedCompaction(t *testing.T) {
582582
}
583583
}
584584

585+
func TestLoop_BudgetExceeded(t *testing.T) {
586+
mp := &mockProvider{
587+
responses: []*provider.ChatResponse{
588+
{
589+
Content: []provider.ContentBlock{{Type: provider.ContentText, Text: "Partial result"}},
590+
StopReason: provider.StopEndTurn,
591+
Usage: provider.Usage{InputTokens: 1_000_000, OutputTokens: 1_000_000},
592+
},
593+
},
594+
}
595+
exec := &mockExecutor{results: map[string]string{}}
596+
bt := NewBudgetTracker(0.001, nil) // $0.001 budget — 1M tokens of sonnet costs $18
597+
a := New(mp, "claude-sonnet-4-6", exec, WithUsageRecorder(bt), WithBudgetChecker(bt))
598+
599+
result, err := a.Run(context.Background(), "test")
600+
if err == nil {
601+
t.Fatal("expected budget exceeded error")
602+
}
603+
if !strings.Contains(err.Error(), "budget exceeded") {
604+
t.Errorf("error = %q, want budget exceeded", err.Error())
605+
}
606+
if result != "Partial result" {
607+
t.Errorf("result = %q, want partial text returned", result)
608+
}
609+
}
610+
611+
func TestLoop_BudgetNotExceeded(t *testing.T) {
612+
mp := &mockProvider{
613+
responses: []*provider.ChatResponse{
614+
{
615+
Content: []provider.ContentBlock{{Type: provider.ContentText, Text: "Full result"}},
616+
StopReason: provider.StopEndTurn,
617+
Usage: provider.Usage{InputTokens: 100, OutputTokens: 100},
618+
},
619+
},
620+
}
621+
exec := &mockExecutor{results: map[string]string{}}
622+
bt := NewBudgetTracker(10.0, nil) // $10 budget — more than enough
623+
a := New(mp, "claude-sonnet-4-6", exec, WithUsageRecorder(bt), WithBudgetChecker(bt))
624+
625+
result, err := a.Run(context.Background(), "test")
626+
if err != nil {
627+
t.Fatalf("unexpected error: %v", err)
628+
}
629+
if result != "Full result" {
630+
t.Errorf("result = %q, want Full result", result)
631+
}
632+
}
633+
585634
func TestLoop_TracksLastInputTokens(t *testing.T) {
586635
mp := &mockProvider{
587636
responses: []*provider.ChatResponse{

agent/budget.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package agent
2+
3+
import (
	"errors"
	"fmt"
	"sync"

	"github.com/73ai/openbotkit/provider"
)
9+
10+
// BudgetChecker reports whether accumulated spending has gone over a budget.
type BudgetChecker interface {
	// CheckBudget returns a non-nil error once the budget is exhausted and
	// nil otherwise.
	CheckBudget() error
}
14+
15+
// BudgetTracker wraps a UsageRecorder and accumulates cost per call.
16+
// It implements both UsageRecorder and BudgetChecker.
17+
type BudgetTracker struct {
18+
maxBudget float64
19+
inner UsageRecorder
20+
mu sync.Mutex
21+
total float64
22+
}
23+
24+
// NewBudgetTracker creates a tracker that enforces a cost budget.
25+
// If maxBudget is 0, budget checking is disabled (unlimited).
26+
func NewBudgetTracker(maxBudget float64, inner UsageRecorder) *BudgetTracker {
27+
return &BudgetTracker{maxBudget: maxBudget, inner: inner}
28+
}
29+
30+
func (bt *BudgetTracker) RecordUsage(model string, usage provider.Usage) {
31+
cost := provider.EstimateCost(model, usage)
32+
bt.mu.Lock()
33+
bt.total += cost
34+
bt.mu.Unlock()
35+
if bt.inner != nil {
36+
bt.inner.RecordUsage(model, usage)
37+
}
38+
}
39+
40+
func (bt *BudgetTracker) CheckBudget() error {
41+
if bt.maxBudget <= 0 {
42+
return nil
43+
}
44+
bt.mu.Lock()
45+
total := bt.total
46+
bt.mu.Unlock()
47+
if total >= bt.maxBudget {
48+
return fmt.Errorf("budget exceeded: $%.4f spent of $%.4f limit", total, bt.maxBudget)
49+
}
50+
return nil
51+
}
52+
53+
// Total returns the accumulated cost so far.
54+
func (bt *BudgetTracker) Total() float64 {
55+
bt.mu.Lock()
56+
defer bt.mu.Unlock()
57+
return bt.total
58+
}

agent/budget_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package agent
2+
3+
import (
4+
"testing"
5+
6+
"github.com/73ai/openbotkit/provider"
7+
)
8+
9+
type mockRecorder struct {
10+
calls int
11+
}
12+
13+
func (m *mockRecorder) RecordUsage(_ string, _ provider.Usage) {
14+
m.calls++
15+
}
16+
17+
func TestBudgetTracker_UnderBudget(t *testing.T) {
18+
bt := NewBudgetTracker(1.0, nil)
19+
bt.RecordUsage("claude-sonnet-4-6", provider.Usage{InputTokens: 1000, OutputTokens: 1000})
20+
if err := bt.CheckBudget(); err != nil {
21+
t.Errorf("unexpected error: %v", err)
22+
}
23+
}
24+
25+
func TestBudgetTracker_ExceedsBudget(t *testing.T) {
26+
bt := NewBudgetTracker(0.01, nil)
27+
// 1M tokens of sonnet = $18, well over $0.01
28+
bt.RecordUsage("claude-sonnet-4-6", provider.Usage{InputTokens: 1_000_000, OutputTokens: 1_000_000})
29+
if err := bt.CheckBudget(); err == nil {
30+
t.Error("expected budget exceeded error")
31+
}
32+
}
33+
34+
func TestBudgetTracker_Unlimited(t *testing.T) {
35+
bt := NewBudgetTracker(0, nil)
36+
bt.RecordUsage("claude-opus-4-6", provider.Usage{InputTokens: 10_000_000, OutputTokens: 10_000_000})
37+
if err := bt.CheckBudget(); err != nil {
38+
t.Errorf("unlimited budget should never error: %v", err)
39+
}
40+
}
41+
42+
func TestBudgetTracker_ChainsInnerRecorder(t *testing.T) {
43+
inner := &mockRecorder{}
44+
bt := NewBudgetTracker(1.0, inner)
45+
bt.RecordUsage("claude-sonnet-4-6", provider.Usage{InputTokens: 100})
46+
bt.RecordUsage("claude-sonnet-4-6", provider.Usage{InputTokens: 200})
47+
if inner.calls != 2 {
48+
t.Errorf("inner.calls = %d, want 2", inner.calls)
49+
}
50+
}
51+
52+
func TestBudgetTracker_Total(t *testing.T) {
53+
bt := NewBudgetTracker(1.0, nil)
54+
if bt.Total() != 0 {
55+
t.Errorf("initial total = %f, want 0", bt.Total())
56+
}
57+
bt.RecordUsage("claude-sonnet-4-6", provider.Usage{InputTokens: 1000, OutputTokens: 1000})
58+
if bt.Total() <= 0 {
59+
t.Error("expected positive total after recording usage")
60+
}
61+
}

agent/tools/agent_runner.go

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func (r *AgentRunner) buildArgs(opts runOptions) []string {
119119
}
120120

121121
func (r *AgentRunner) buildEnv() []string {
122-
env := os.Environ()
122+
env := scrubEnv(os.Environ())
123123
if r.info.Kind == AgentClaude {
124124
return filterEnv(env, "CLAUDECODE")
125125
}
@@ -137,3 +137,74 @@ func filterEnv(env []string, key string) []string {
137137
}
138138
return filtered
139139
}
140+
141+
// scrubEnv removes environment variables that contain sensitive values
142+
// (API keys, tokens, secrets, passwords) from the given list.
143+
func scrubEnv(env []string) []string {
144+
filtered := make([]string, 0, len(env))
145+
for _, e := range env {
146+
eqIdx := strings.IndexByte(e, '=')
147+
if eqIdx < 0 {
148+
filtered = append(filtered, e)
149+
continue
150+
}
151+
key := e[:eqIdx]
152+
if !isSensitiveKey(key) {
153+
filtered = append(filtered, e)
154+
}
155+
}
156+
return filtered
157+
}
158+
159+
// Classification tables used by isSensitiveKey. All sensitive patterns are
// written in upper case and are matched against the upper-cased variable name.
var (
	// safeKeys are variable names that are always passed through.
	safeKeys = map[string]bool{
		"PATH": true, "HOME": true, "USER": true, "SHELL": true,
		"TERM": true, "LANG": true, "TMPDIR": true, "PWD": true,
		"GOPATH": true, "GOROOT": true, "GOBIN": true,
		"EDITOR": true, "VISUAL": true, "PAGER": true,
		"HOSTNAME": true, "LOGNAME": true, "DISPLAY": true,
		"COLORTERM": true, "TERM_PROGRAM": true, "SHLVL": true,
	}

	// safePrefixes mark whole families of variables as safe.
	safePrefixes = []string{"LC_", "XDG_"}

	// sensitivePrefixes mark variables as sensitive by how the name starts.
	sensitivePrefixes = []string{
		"ANTHROPIC_", "OPENAI_", "GOOGLE_API_", "GEMINI_",
		"GROQ_", "OPENROUTER_", "AWS_SECRET_", "CLAUDECODE",
	}

	// sensitiveSuffixes mark variables as sensitive by how the name ends.
	sensitiveSuffixes = []string{
		"_KEY", "_SECRET", "_TOKEN", "_PASSWORD", "_CREDENTIAL", "_AUTH",
		"_PRIVATE_KEY", "_DSN",
	}

	// sensitiveExact lists full names that are sensitive even though they
	// match none of the prefix or suffix patterns.
	sensitiveExact = map[string]bool{
		"GITHUB_TOKEN": true, "GH_TOKEN": true,
		"DATABASE_URL": true, "REDIS_URL": true, "MONGODB_URI": true,
		"AMQP_URL": true, "ELASTICSEARCH_URL": true,
	}
)

// isSensitiveKey reports whether the environment variable named key should be
// withheld. The safe list (safeKeys, safePrefixes) is consulted first and takes
// precedence. After that the name is upper-cased once and compared against the
// exact, prefix and suffix tables, so every sensitive check is
// case-insensitive.
func isSensitiveKey(key string) bool {
	if safeKeys[key] {
		return false
	}
	for _, p := range safePrefixes {
		if strings.HasPrefix(key, p) {
			return false
		}
	}

	// Normalize before any sensitive lookup. The exact-name table must see the
	// same upper-cased form as the prefix and suffix checks, otherwise a name
	// such as "database_url" slips through.
	upper := strings.ToUpper(key)
	if sensitiveExact[upper] {
		return true
	}
	for _, p := range sensitivePrefixes {
		if strings.HasPrefix(upper, p) {
			return true
		}
	}
	for _, s := range sensitiveSuffixes {
		if strings.HasSuffix(upper, s) {
			return true
		}
	}
	return false
}

0 commit comments

Comments
 (0)