Skip to content

Commit 7bd4e3a

Browse files
Merge pull request #30 from priyanshujain/feat-agent-safety
Add agent safety and resilience features
2 parents 7c32762 + e806dad commit 7bd4e3a

File tree

16 files changed

+1311
-14
lines changed

16 files changed

+1311
-14
lines changed

agent/agent.go

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ type ToolExecutor interface {
1515

1616
// Agent orchestrates the conversation between user, LLM, and tools.
1717
type Agent struct {
18-
provider provider.Provider
19-
model string
20-
system string
21-
executor ToolExecutor
22-
history []provider.Message
23-
maxIter int
18+
provider provider.Provider
19+
model string
20+
system string
21+
executor ToolExecutor
22+
history []provider.Message
23+
maxIter int
24+
maxHistory int
25+
rateLimiter *provider.RateLimiter
2426
}
2527

2628
// Option configures an Agent.
@@ -36,13 +38,24 @@ func WithMaxIterations(n int) Option {
3638
return func(a *Agent) { a.maxIter = n }
3739
}
3840

41+
// WithMaxHistory sets the maximum number of history messages before compaction.
42+
func WithMaxHistory(n int) Option {
43+
return func(a *Agent) { a.maxHistory = n }
44+
}
45+
46+
// WithRateLimit sets a rate limit on LLM API calls (requests per hour).
47+
func WithRateLimit(requestsPerHour int) Option {
48+
return func(a *Agent) { a.rateLimiter = provider.NewRateLimiter(requestsPerHour) }
49+
}
50+
3951
// New creates a new Agent.
4052
func New(p provider.Provider, model string, executor ToolExecutor, opts ...Option) *Agent {
4153
a := &Agent{
42-
provider: p,
43-
model: model,
44-
executor: executor,
45-
maxIter: 25,
54+
provider: p,
55+
model: model,
56+
executor: executor,
57+
maxIter: 25,
58+
maxHistory: defaultMaxHistory,
4659
}
4760
for _, opt := range opts {
4861
opt(a)
@@ -56,6 +69,12 @@ func (a *Agent) Run(ctx context.Context, input string) (string, error) {
5669
a.history = append(a.history, provider.NewTextMessage(provider.RoleUser, input))
5770

5871
for i := range a.maxIter {
72+
a.compactHistory()
73+
if a.rateLimiter != nil {
74+
if err := a.rateLimiter.Wait(ctx); err != nil {
75+
return "", fmt.Errorf("rate limiter: %w", err)
76+
}
77+
}
5978
resp, err := a.provider.Chat(ctx, provider.ChatRequest{
6079
Model: a.model,
6180
System: a.system,
@@ -82,9 +101,9 @@ func (a *Agent) Run(ctx context.Context, input string) (string, error) {
82101
for _, call := range resp.ToolCalls() {
83102
output, err := a.executor.Execute(ctx, call)
84103
isError := err != nil
85-
content := output
104+
content := ScrubCredentials(output)
86105
if isError {
87-
content = err.Error()
106+
content = ScrubCredentials(err.Error())
88107
}
89108
results = append(results, provider.ContentBlock{
90109
Type: provider.ContentToolResult,

agent/agent_test.go

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"strings"
78
"testing"
89

910
"github.com/priyanshujain/openbotkit/provider"
@@ -161,6 +162,30 @@ func TestLoop_MultiToolSequence(t *testing.T) {
161162
}
162163
}
163164

165+
// errorExecutor returns an error for specified tools.
166+
type errorExecutor struct {
167+
successes map[string]string
168+
errors map[string]string
169+
calls []provider.ToolCall
170+
}
171+
172+
func (m *errorExecutor) Execute(_ context.Context, call provider.ToolCall) (string, error) {
173+
m.calls = append(m.calls, call)
174+
if errMsg, ok := m.errors[call.Name]; ok {
175+
return "", fmt.Errorf("%s", errMsg)
176+
}
177+
if result, ok := m.successes[call.Name]; ok {
178+
return result, nil
179+
}
180+
return "", fmt.Errorf("unknown tool %q", call.Name)
181+
}
182+
183+
func (m *errorExecutor) ToolSchemas() []provider.Tool {
184+
return []provider.Tool{
185+
{Name: "bash", Description: "Run a command", InputSchema: json.RawMessage(`{"type":"object"}`)},
186+
}
187+
}
188+
164189
func TestLoop_MaxIterations(t *testing.T) {
165190
// Provider always returns tool_use — should stop at max iterations.
166191
alwaysToolUse := &provider.ChatResponse{
@@ -188,3 +213,173 @@ func TestLoop_MaxIterations(t *testing.T) {
188213
t.Errorf("error = %q", got)
189214
}
190215
}
216+
217+
func TestLoop_ScrubsToolOutput(t *testing.T) {
218+
mp := &mockProvider{
219+
responses: []*provider.ChatResponse{
220+
{
221+
Content: []provider.ContentBlock{
222+
{Type: provider.ContentToolUse, ToolCall: &provider.ToolCall{
223+
ID: "c1", Name: "bash", Input: json.RawMessage(`{}`),
224+
}},
225+
},
226+
StopReason: provider.StopToolUse,
227+
},
228+
{
229+
Content: []provider.ContentBlock{{Type: provider.ContentText, Text: "Done"}},
230+
StopReason: provider.StopEndTurn,
231+
},
232+
},
233+
}
234+
exec := &mockExecutor{results: map[string]string{
235+
"bash": "TOKEN=sk-secret-key-12345678",
236+
}}
237+
a := New(mp, "test-model", exec)
238+
239+
_, err := a.Run(context.Background(), "show env")
240+
if err != nil {
241+
t.Fatalf("Run: %v", err)
242+
}
243+
244+
// The second request should contain the scrubbed tool result.
245+
if len(mp.requests) < 2 {
246+
t.Fatalf("expected 2 requests, got %d", len(mp.requests))
247+
}
248+
msgs := mp.requests[1].Messages
249+
// Last message should be the tool result.
250+
last := msgs[len(msgs)-1]
251+
content := last.Content[0].ToolResult.Content
252+
if strings.Contains(content, "sk-secret-key-12345678") {
253+
t.Errorf("tool output not scrubbed: %q", content)
254+
}
255+
if !strings.Contains(content, "****") {
256+
t.Errorf("expected redacted content, got: %q", content)
257+
}
258+
}
259+
260+
func TestLoop_ScrubsToolError(t *testing.T) {
261+
mp := &mockProvider{
262+
responses: []*provider.ChatResponse{
263+
{
264+
Content: []provider.ContentBlock{
265+
{Type: provider.ContentToolUse, ToolCall: &provider.ToolCall{
266+
ID: "c1", Name: "bash", Input: json.RawMessage(`{}`),
267+
}},
268+
},
269+
StopReason: provider.StopToolUse,
270+
},
271+
{
272+
Content: []provider.ContentBlock{{Type: provider.ContentText, Text: "Done"}},
273+
StopReason: provider.StopEndTurn,
274+
},
275+
},
276+
}
277+
exec := &errorExecutor{
278+
errors: map[string]string{"bash": "failed: password=supersecret123"},
279+
}
280+
a := New(mp, "test-model", exec)
281+
282+
_, err := a.Run(context.Background(), "try")
283+
if err != nil {
284+
t.Fatalf("Run: %v", err)
285+
}
286+
287+
msgs := mp.requests[1].Messages
288+
last := msgs[len(msgs)-1]
289+
content := last.Content[0].ToolResult.Content
290+
if strings.Contains(content, "supersecret123") {
291+
t.Errorf("tool error not scrubbed: %q", content)
292+
}
293+
if !last.Content[0].ToolResult.IsError {
294+
t.Error("expected IsError=true")
295+
}
296+
}
297+
298+
func TestLoop_ProviderChatError(t *testing.T) {
299+
mp := &mockProvider{responses: nil} // no responses = error on first call
300+
exec := &mockExecutor{results: map[string]string{}}
301+
a := New(mp, "test-model", exec)
302+
303+
_, err := a.Run(context.Background(), "hi")
304+
if err == nil {
305+
t.Fatal("expected error from provider")
306+
}
307+
if !strings.Contains(err.Error(), "chat (iteration 0)") {
308+
t.Errorf("error = %q, expected chat iteration error", err.Error())
309+
}
310+
}
311+
312+
func TestLoop_CompactsHistory(t *testing.T) {
313+
// Build a provider that does one tool call per iteration for 15 rounds, then ends.
314+
var responses []*provider.ChatResponse
315+
for i := range 15 {
316+
responses = append(responses, &provider.ChatResponse{
317+
Content: []provider.ContentBlock{
318+
{Type: provider.ContentToolUse, ToolCall: &provider.ToolCall{
319+
ID: fmt.Sprintf("c%d", i), Name: "bash", Input: json.RawMessage(`{}`),
320+
}},
321+
},
322+
StopReason: provider.StopToolUse,
323+
})
324+
}
325+
responses = append(responses, &provider.ChatResponse{
326+
Content: []provider.ContentBlock{{Type: provider.ContentText, Text: "Done"}},
327+
StopReason: provider.StopEndTurn,
328+
})
329+
330+
mp := &mockProvider{responses: responses}
331+
exec := &mockExecutor{results: map[string]string{"bash": "ok"}}
332+
// Without compaction, history would be 1 user + 15*(assistant+result) + final assistant = 32 messages.
333+
// With maxHistory=10, compaction fires repeatedly, keeping history bounded.
334+
a := New(mp, "test-model", exec, WithMaxHistory(10), WithMaxIterations(20))
335+
336+
_, err := a.Run(context.Background(), "go")
337+
if err != nil {
338+
t.Fatalf("Run: %v", err)
339+
}
340+
341+
// History should have been compacted (not 32 messages).
342+
if len(a.history) > 22 {
343+
t.Errorf("history not compacted: len=%d, want <=22", len(a.history))
344+
}
345+
}
346+
347+
func TestLoop_RateLimiterContextCancel(t *testing.T) {
348+
mp := &mockProvider{
349+
responses: []*provider.ChatResponse{
350+
{Content: []provider.ContentBlock{{Type: provider.ContentText, Text: "ok"}}, StopReason: provider.StopEndTurn},
351+
},
352+
}
353+
exec := &mockExecutor{results: map[string]string{}}
354+
355+
// Create agent with extremely low rate limit (1/hour).
356+
a := New(mp, "test-model", exec, WithRateLimit(1))
357+
358+
// First call uses burst, should succeed.
359+
_, err := a.Run(context.Background(), "first")
360+
if err != nil {
361+
t.Fatalf("first Run: %v", err)
362+
}
363+
364+
// Exhaust remaining burst by calling multiple times.
365+
for range 9 {
366+
mp.idx = 0
367+
mp.requests = nil
368+
a.history = nil
369+
_, _ = a.Run(context.Background(), "burst")
370+
}
371+
372+
// Now cancel context; should fail on rate limiter.
373+
ctx, cancel := context.WithCancel(context.Background())
374+
cancel()
375+
mp.idx = 0
376+
mp.requests = nil
377+
a.history = nil
378+
_, err = a.Run(ctx, "should fail")
379+
if err == nil {
380+
t.Fatal("expected rate limiter error")
381+
}
382+
if !strings.Contains(err.Error(), "rate limiter") {
383+
t.Errorf("error = %q, expected rate limiter error", err.Error())
384+
}
385+
}

agent/compact.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package agent
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/priyanshujain/openbotkit/provider"
7+
)
8+
9+
const defaultMaxHistory = 40
10+
11+
// compactHistory trims history to half of maxHistory when it exceeds maxHistory,
12+
// prepending a summary placeholder for the removed messages.
13+
func (a *Agent) compactHistory() {
14+
if len(a.history) <= a.maxHistory {
15+
return
16+
}
17+
keep := a.maxHistory / 2
18+
if keep < 1 {
19+
keep = 1
20+
}
21+
removed := len(a.history) - keep
22+
summary := provider.NewTextMessage(provider.RoleUser,
23+
fmt.Sprintf("[Earlier conversation: %d messages removed]", removed))
24+
a.history = append([]provider.Message{summary}, a.history[removed:]...)
25+
}

0 commit comments

Comments
 (0)