diff --git a/agent/agent.go b/agent/agent.go index aee77cfb..2bab9ecc 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -40,6 +40,7 @@ type Agent struct { summarizer Summarizer rateLimiter *provider.RateLimiter usageRecorder UsageRecorder + budgetChecker BudgetChecker } // Option configures an Agent. @@ -95,6 +96,11 @@ func WithSummarizer(s Summarizer) Option { return func(a *Agent) { a.summarizer = s } } +// WithBudgetChecker sets a budget checker that is called after each LLM call. +func WithBudgetChecker(bc BudgetChecker) Option { + return func(a *Agent) { a.budgetChecker = bc } +} + // New creates a new Agent. func New(p provider.Provider, model string, executor ToolExecutor, opts ...Option) *Agent { a := &Agent{ @@ -143,6 +149,16 @@ func (a *Agent) Run(ctx context.Context, input string) (string, error) { if a.usageRecorder != nil { a.usageRecorder.RecordUsage(a.model, resp.Usage) } + if a.budgetChecker != nil { + if err := a.budgetChecker.CheckBudget(); err != nil { + // Return partial text on budget exceeded (graceful). + text := resp.TextContent() + if text == "" { + text = "(budget exceeded before completion)" + } + return text, err + } + } // Append assistant response to history. a.history = append(a.history, provider.Message{ diff --git a/agent/agent_test.go b/agent/agent_test.go index 1a1deecf..a27d1f00 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -582,6 +582,55 @@ func TestLoop_TokenBasedCompaction(t *testing.T) { } } +func TestLoop_BudgetExceeded(t *testing.T) { + mp := &mockProvider{ + responses: []*provider.ChatResponse{ + { + Content: []provider.ContentBlock{{Type: provider.ContentText, Text: "Partial result"}}, + StopReason: provider.StopEndTurn, + Usage: provider.Usage{InputTokens: 1_000_000, OutputTokens: 1_000_000}, + }, + }, + } + exec := &mockExecutor{results: map[string]string{}} + bt := NewBudgetTracker(0.001, nil) // $0.001 budget — 1M tokens of sonnet costs $18 + a := New(mp, "claude-sonnet-4-6", exec, WithUsageRecorder(bt), WithBudgetChecker(bt)) + + result, err := a.Run(context.Background(), "test") + if err == nil { + t.Fatal("expected budget exceeded error") + } + if !strings.Contains(err.Error(), "budget exceeded") { + t.Errorf("error = %q, want budget exceeded", err.Error()) + } + if result != "Partial result" { + t.Errorf("result = %q, want partial text returned", result) + } +} + +func TestLoop_BudgetNotExceeded(t *testing.T) { + mp := &mockProvider{ + responses: []*provider.ChatResponse{ + { + Content: []provider.ContentBlock{{Type: provider.ContentText, Text: "Full result"}}, + StopReason: provider.StopEndTurn, + Usage: provider.Usage{InputTokens: 100, OutputTokens: 100}, + }, + }, + } + exec := &mockExecutor{results: map[string]string{}} + bt := NewBudgetTracker(10.0, nil) // $10 budget — more than enough + a := New(mp, "claude-sonnet-4-6", exec, WithUsageRecorder(bt), WithBudgetChecker(bt)) + + result, err := a.Run(context.Background(), "test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "Full result" { + t.Errorf("result = %q, want Full result", result) + } +} + func TestLoop_TracksLastInputTokens(t *testing.T) { mp := &mockProvider{ responses: []*provider.ChatResponse{ diff --git a/agent/budget.go b/agent/budget.go new file mode 100644 index 00000000..8766a1fa --- /dev/null +++ b/agent/budget.go @@ -0,0 +1,58 @@ +package agent + +import ( + "fmt" + "sync" + + "github.com/73ai/openbotkit/provider" +) + +// BudgetChecker checks whether the accumulated cost has exceeded a budget. +type BudgetChecker interface { + CheckBudget() error +} + +// BudgetTracker wraps a UsageRecorder and accumulates cost per call. +// It implements both UsageRecorder and BudgetChecker. +type BudgetTracker struct { + maxBudget float64 + inner UsageRecorder + mu sync.Mutex + total float64 +} + +// NewBudgetTracker creates a tracker that enforces a cost budget. +// If maxBudget is 0, budget checking is disabled (unlimited). +func NewBudgetTracker(maxBudget float64, inner UsageRecorder) *BudgetTracker { + return &BudgetTracker{maxBudget: maxBudget, inner: inner} +} + +func (bt *BudgetTracker) RecordUsage(model string, usage provider.Usage) { + cost := provider.EstimateCost(model, usage) + bt.mu.Lock() + bt.total += cost + bt.mu.Unlock() + if bt.inner != nil { + bt.inner.RecordUsage(model, usage) + } +} + +func (bt *BudgetTracker) CheckBudget() error { + if bt.maxBudget <= 0 { + return nil + } + bt.mu.Lock() + total := bt.total + bt.mu.Unlock() + if total >= bt.maxBudget { + return fmt.Errorf("budget exceeded: $%.4f spent of $%.4f limit", total, bt.maxBudget) + } + return nil +} + +// Total returns the accumulated cost so far. +func (bt *BudgetTracker) Total() float64 { + bt.mu.Lock() + defer bt.mu.Unlock() + return bt.total +} diff --git a/agent/budget_test.go b/agent/budget_test.go new file mode 100644 index 00000000..b03a9cda --- /dev/null +++ b/agent/budget_test.go @@ -0,0 +1,61 @@ +package agent + +import ( + "testing" + + "github.com/73ai/openbotkit/provider" +) + +type mockRecorder struct { + calls int +} + +func (m *mockRecorder) RecordUsage(_ string, _ provider.Usage) { + m.calls++ +} + +func TestBudgetTracker_UnderBudget(t *testing.T) { + bt := NewBudgetTracker(1.0, nil) + bt.RecordUsage("claude-sonnet-4-6", provider.Usage{InputTokens: 1000, OutputTokens: 1000}) + if err := bt.CheckBudget(); err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestBudgetTracker_ExceedsBudget(t *testing.T) { + bt := NewBudgetTracker(0.01, nil) + // 1M tokens of sonnet = $18, well over $0.01 + bt.RecordUsage("claude-sonnet-4-6", provider.Usage{InputTokens: 1_000_000, OutputTokens: 1_000_000}) + if err := bt.CheckBudget(); err == nil { + t.Error("expected budget exceeded error") + } +} + +func TestBudgetTracker_Unlimited(t *testing.T) { + bt := NewBudgetTracker(0, nil) + bt.RecordUsage("claude-opus-4-6", provider.Usage{InputTokens: 10_000_000, OutputTokens: 10_000_000}) + if err := bt.CheckBudget(); err != nil { + t.Errorf("unlimited budget should never error: %v", err) + } +} + +func TestBudgetTracker_ChainsInnerRecorder(t *testing.T) { + inner := &mockRecorder{} + bt := NewBudgetTracker(1.0, inner) + bt.RecordUsage("claude-sonnet-4-6", provider.Usage{InputTokens: 100}) + bt.RecordUsage("claude-sonnet-4-6", provider.Usage{InputTokens: 200}) + if inner.calls != 2 { + t.Errorf("inner.calls = %d, want 2", inner.calls) + } +} + +func TestBudgetTracker_Total(t *testing.T) { + bt := NewBudgetTracker(1.0, nil) + if bt.Total() != 0 { + t.Errorf("initial total = %f, want 0", bt.Total()) + } + bt.RecordUsage("claude-sonnet-4-6", provider.Usage{InputTokens: 1000, OutputTokens: 1000}) + if bt.Total() <= 0 { + t.Error("expected positive total after recording usage") + } +} diff --git a/agent/tools/agent_runner.go b/agent/tools/agent_runner.go index 0410e51d..11aa091f 100644 --- a/agent/tools/agent_runner.go +++ b/agent/tools/agent_runner.go @@ -119,7 +119,7 @@ func (r *AgentRunner) buildArgs(opts runOptions) []string { } func (r *AgentRunner) buildEnv() []string { - env := os.Environ() + env := scrubEnv(os.Environ()) if r.info.Kind == AgentClaude { return filterEnv(env, "CLAUDECODE") } @@ -137,3 +137,74 @@ func filterEnv(env []string, key string) []string { } return filtered } + +// scrubEnv removes environment variables that contain sensitive values +// (API keys, tokens, secrets, passwords) from the given list. +func scrubEnv(env []string) []string { + filtered := make([]string, 0, len(env)) + for _, e := range env { + eqIdx := strings.IndexByte(e, '=') + if eqIdx < 0 { + filtered = append(filtered, e) + continue + } + key := e[:eqIdx] + if !isSensitiveKey(key) { + filtered = append(filtered, e) + } + } + return filtered +} + +var safeKeys = map[string]bool{ + "PATH": true, "HOME": true, "USER": true, "SHELL": true, + "TERM": true, "LANG": true, "TMPDIR": true, "PWD": true, + "GOPATH": true, "GOROOT": true, "GOBIN": true, + "EDITOR": true, "VISUAL": true, "PAGER": true, + "HOSTNAME": true, "LOGNAME": true, "DISPLAY": true, + "COLORTERM": true, "TERM_PROGRAM": true, "SHLVL": true, +} + +var safePrefixes = []string{"LC_", "XDG_"} + +var sensitivePrefixes = []string{ + "ANTHROPIC_", "OPENAI_", "GOOGLE_API_", "GEMINI_", + "GROQ_", "OPENROUTER_", "AWS_SECRET_", "CLAUDECODE", +} + +var sensitiveSuffixes = []string{ + "_KEY", "_SECRET", "_TOKEN", "_PASSWORD", "_CREDENTIAL", "_AUTH", + "_PRIVATE_KEY", "_DSN", +} + +var sensitiveExact = map[string]bool{ + "GITHUB_TOKEN": true, "GH_TOKEN": true, + "DATABASE_URL": true, "REDIS_URL": true, "MONGODB_URI": true, + "AMQP_URL": true, "ELASTICSEARCH_URL": true, +} + +func isSensitiveKey(key string) bool { + if safeKeys[key] { + return false + } + for _, p := range safePrefixes { + if strings.HasPrefix(key, p) { + return false + } + } + if sensitiveExact[key] { + return true + } + upper := strings.ToUpper(key) + for _, p := range sensitivePrefixes { + if strings.HasPrefix(upper, p) { + return true + } + } + for _, s := range sensitiveSuffixes { + if strings.HasSuffix(upper, s) { + return true + } + } + return false +} diff --git a/agent/tools/agent_runner_test.go b/agent/tools/agent_runner_test.go index 59d44d45..6abca660 100644 --- a/agent/tools/agent_runner_test.go +++ b/agent/tools/agent_runner_test.go @@ -134,36 +134,28 @@ func TestAgentRunner_StripsCLAUDECODE(t *testing.T) { } } -func TestAgentRunner_GeminiKeepsCLAUDECODE(t *testing.T) { - t.Setenv("CLAUDECODE", "1") +func TestAgentRunner_GeminiScrubsSecrets(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "sk-test") + t.Setenv("ANTHROPIC_API_KEY", "sk-ant-test") r := NewAgentRunner(AgentInfo{Kind: AgentGemini, Binary: "gemini"}) env := r.buildEnv() - found := false for _, e := range env { - if e == "CLAUDECODE=1" { - found = true - break + if strings.HasPrefix(e, "OPENAI_API_KEY=") || strings.HasPrefix(e, "ANTHROPIC_API_KEY=") { + t.Errorf("sensitive var not scrubbed: %s", e) } } - if !found { - t.Error("CLAUDECODE should NOT be stripped for gemini") - } } -func TestAgentRunner_CodexKeepsCLAUDECODE(t *testing.T) { - t.Setenv("CLAUDECODE", "1") +func TestAgentRunner_CodexScrubsSecrets(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "sk-test") + t.Setenv("GITHUB_TOKEN", "ghp_test") r := NewAgentRunner(AgentInfo{Kind: AgentCodex, Binary: "codex"}) env := r.buildEnv() - found := false for _, e := range env { - if e == "CLAUDECODE=1" { - found = true - break + if strings.HasPrefix(e, "OPENAI_API_KEY=") || strings.HasPrefix(e, "GITHUB_TOKEN=") { + t.Errorf("sensitive var not scrubbed: %s", e) } } - if !found { - t.Error("CLAUDECODE should NOT be stripped for codex") - } } func TestAgentRunner_Timeout(t *testing.T) { @@ -187,6 +179,87 @@ func TestFilterEnv(t *testing.T) { } } +func TestIsSensitiveKey(t *testing.T) { + tests := []struct { + key string + sensitive bool + }{ + {"PATH", false}, + {"HOME", false}, + {"USER", false}, + {"TERM", false}, + {"GOPATH", false}, + {"LC_ALL", false}, + {"XDG_CONFIG_HOME", false}, + {"ANTHROPIC_API_KEY", true}, + {"OPENAI_API_KEY", true}, + {"GOOGLE_API_KEY", true}, + {"GEMINI_API_KEY", true}, + {"GROQ_API_KEY", true}, + {"OPENROUTER_API_KEY", true}, + {"AWS_SECRET_ACCESS_KEY", true}, + {"GITHUB_TOKEN", true}, + {"GH_TOKEN", true}, + {"MY_SECRET", true}, + {"DB_PASSWORD", true}, + {"AUTH_TOKEN", true}, + {"SERVICE_CREDENTIAL", true}, + {"SOME_AUTH", true}, + {"CLAUDECODE", true}, + {"DATABASE_URL", true}, + {"REDIS_URL", true}, + {"MONGODB_URI", true}, + {"TLS_PRIVATE_KEY", true}, + {"PG_DSN", true}, + {"RANDOM_VAR", false}, + {"EDITOR", false}, + } + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + if got := isSensitiveKey(tt.key); got != tt.sensitive { + t.Errorf("isSensitiveKey(%q) = %v, want %v", tt.key, got, tt.sensitive) + } + }) + } +} + +func TestScrubEnv_RemovesSensitiveKeys(t *testing.T) { + env := []string{ + "PATH=/usr/bin", + "HOME=/home/user", + "ANTHROPIC_API_KEY=sk-ant-123", + "OPENAI_API_KEY=sk-123", + "GITHUB_TOKEN=ghp_abc", + "MY_SECRET=s3cret", + "TERM=xterm", + } + got := scrubEnv(env) + allowed := map[string]bool{"PATH=/usr/bin": true, "HOME=/home/user": true, "TERM=xterm": true} + if len(got) != len(allowed) { + t.Fatalf("got %d entries, want %d: %v", len(got), len(allowed), got) + } + for _, e := range got { + if !allowed[e] { + t.Errorf("unexpected env var: %s", e) + } + } +} + +func TestScrubEnv_KeepsSafeVars(t *testing.T) { + env := []string{ + "PATH=/usr/bin", + "HOME=/home/user", + "GOPATH=/go", + "LC_ALL=en_US.UTF-8", + "XDG_CONFIG_HOME=/home/user/.config", + "EDITOR=vim", + } + got := scrubEnv(env) + if len(got) != len(env) { + t.Errorf("scrubEnv removed safe vars: got %d, want %d", len(got), len(env)) + } +} + func TestAgentRunner_RealClaude(t *testing.T) { if _, err := exec.LookPath("claude"); err != nil { t.Skip("claude not on PATH") diff --git a/agent/tools/agent_stream.go b/agent/tools/agent_stream.go index b57715df..d9ddbff9 100644 --- a/agent/tools/agent_stream.go +++ b/agent/tools/agent_stream.go @@ -121,7 +121,7 @@ func (r *StreamRunner) buildStreamArgs(opts runOptions) []string { } func (r *StreamRunner) buildEnv() []string { - env := os.Environ() + env := scrubEnv(os.Environ()) if r.info.Kind == AgentClaude { return filterEnv(env, "CLAUDECODE") } diff --git a/agent/tools/delegate_task.go b/agent/tools/delegate_task.go index 18252285..c7160738 100644 --- a/agent/tools/delegate_task.go +++ b/agent/tools/delegate_task.go @@ -11,6 +11,8 @@ import ( "path/filepath" "strings" "time" + + "github.com/73ai/openbotkit/agent/audit" ) const defaultDelegateTimeout = 5 * time.Minute @@ -23,6 +25,7 @@ type DelegateTaskConfig struct { Tracker *TaskTracker // nil = sync-only (Phase 1 behavior) ApprovalRules *ApprovalRuleSet ScratchDir string // directory for writing result files + AuditLogger *audit.Logger } const progressThrottle = 30 * time.Second @@ -37,6 +40,7 @@ type DelegateTaskTool struct { tracker *TaskTracker approvalRules *ApprovalRuleSet scratchDir string + auditLogger *audit.Logger } // NewDelegateTaskTool creates a new delegate_task tool. @@ -60,6 +64,7 @@ func NewDelegateTaskTool(cfg DelegateTaskConfig) *DelegateTaskTool { tracker: cfg.Tracker, approvalRules: cfg.ApprovalRules, scratchDir: cfg.ScratchDir, + auditLogger: cfg.AuditLogger, } } @@ -237,11 +242,13 @@ func (d *DelegateTaskTool) runAsync( output, err := sr.RunStream(ctx, task, d.timeout, onEvent, runOpts...) if err != nil { d.tracker.Fail(taskID, err.Error()) + d.logAudit(taskID, preview, "", err.Error()) d.interactor.Notify(fmt.Sprintf("Task failed: %s. %s", preview, err)) return } summary := d.writeResultFile(output, task) d.tracker.Complete(taskID, summary) + d.logAudit(taskID, preview, summary, "") d.interactor.Notify(fmt.Sprintf("Task completed: %s. Use check_task to see results.", preview)) return } @@ -250,11 +257,13 @@ func (d *DelegateTaskTool) runAsync( output, err := runner.Run(ctx, task, d.timeout, runOpts...) if err != nil { d.tracker.Fail(taskID, err.Error()) + d.logAudit(taskID, preview, "", err.Error()) d.interactor.Notify(fmt.Sprintf("Task failed: %s. %s", preview, err)) return } summary := d.writeResultFile(output, task) d.tracker.Complete(taskID, summary) + d.logAudit(taskID, preview, summary, "") d.interactor.Notify(fmt.Sprintf("Task completed: %s. Use check_task to see results.", preview)) } @@ -305,6 +314,19 @@ func (d *DelegateTaskTool) selectRunner(agent string) (AgentRunnerInterface, Age return runner, kind, nil } +func (d *DelegateTaskTool) logAudit(taskID, input, output, errMsg string) { + if d.auditLogger == nil { + return + } + d.auditLogger.Log(audit.Entry{ + Context: "delegated", + ToolName: "delegate_task", + InputSummary: taskID + ": " + input, + OutputSummary: output, + Error: errMsg, + }) +} + func (d *DelegateTaskTool) agentEnumJSON() string { names := make([]string, len(d.agents)) for i, a := range d.agents { diff --git a/agent/tools/schedule.go b/agent/tools/schedule.go index bb675228..7a8708ec 100644 --- a/agent/tools/schedule.go +++ b/agent/tools/schedule.go @@ -50,7 +50,7 @@ func NewCreateScheduleTool(deps ScheduleToolDeps) *CreateScheduleTool { func (t *CreateScheduleTool) Name() string { return "create_schedule" } func (t *CreateScheduleTool) Description() string { - return "Create a recurring or one-shot scheduled task" + return "Create a recurring, one-shot, or reactive scheduled task" } func (t *CreateScheduleTool) InputSchema() json.RawMessage { return json.RawMessage(`{ @@ -58,7 +58,7 @@ func (t *CreateScheduleTool) InputSchema() json.RawMessage { "properties": { "type": { "type": "string", - "enum": ["recurring", "one_shot"], + "enum": ["recurring", "one_shot", "reactive"], "description": "Schedule type" }, "cron_expr": { @@ -80,6 +80,24 @@ func (t *CreateScheduleTool) InputSchema() json.RawMessage { "description": { "type": "string", "description": "Human-readable description of the schedule" + }, + "model_tier": { + "type": "string", + "enum": ["fast", "default", "complex", "nano"], + "description": "Model tier for the task (default: fast)" + }, + "max_budget_usd": { + "type": "number", + "description": "Maximum cost budget in USD (0 = unlimited)" + }, + "trigger_source": { + "type": "string", + "enum": ["gmail", "whatsapp", "imessage", "applenotes"], + "description": "Data source to watch (for reactive type)" + }, + "trigger_query": { + "type": "string", + "description": "SQL WHERE clause to match new rows (for reactive type)" } }, "required": ["type", "task", "timezone"] @@ -87,12 +105,16 @@ func (t *CreateScheduleTool) InputSchema() json.RawMessage { } type createScheduleInput struct { - Type string `json:"type"` - CronExpr string `json:"cron_expr"` - ScheduledAt string `json:"scheduled_at"` - Task string `json:"task"` - Timezone string `json:"timezone"` - Description string `json:"description"` + Type string `json:"type"` + CronExpr string `json:"cron_expr"` + ScheduledAt string `json:"scheduled_at"` + Task string `json:"task"` + Timezone string `json:"timezone"` + Description string `json:"description"` + ModelTier string `json:"model_tier"` + MaxBudgetUSD float64 `json:"max_budget_usd"` + TriggerSource string `json:"trigger_source"` + TriggerQuery string `json:"trigger_query"` } func (t *CreateScheduleTool) Execute(_ context.Context, input json.RawMessage) (string, error) { @@ -110,13 +132,20 @@ func (t *CreateScheduleTool) Execute(_ context.Context, input json.RawMessage) ( return "", fmt.Errorf("invalid timezone %q: %w", in.Timezone, err) } + modelTier := in.ModelTier + if modelTier == "" { + modelTier = "fast" + } + s := &scheduler.Schedule{ - Type: scheduler.ScheduleType(in.Type), - Task: in.Task, - Channel: t.deps.Channel, - ChannelMeta: t.deps.ChannelMeta, - Timezone: in.Timezone, - Description: in.Description, + Type: scheduler.ScheduleType(in.Type), + Task: in.Task, + Channel: t.deps.Channel, + ChannelMeta: t.deps.ChannelMeta, + Timezone: in.Timezone, + Description: in.Description, + ModelTier: modelTier, + MaxBudgetUSD: in.MaxBudgetUSD, } switch s.Type { @@ -150,8 +179,21 @@ func (t *CreateScheduleTool) Execute(_ context.Context, input json.RawMessage) ( } s.ScheduledAt = &scheduledAt + case scheduler.Reactive: + if in.TriggerSource == "" { + return "", fmt.Errorf("trigger_source is required for reactive schedules") + } + if err := scheduler.ValidateTriggerSource(in.TriggerSource); err != nil { + return "", err + } + if err := scheduler.ValidateTriggerQuery(in.TriggerSource, in.TriggerQuery); err != nil { + return "", err + } + s.TriggerSource = in.TriggerSource + s.TriggerQuery = in.TriggerQuery + default: - return "", fmt.Errorf("type must be 'recurring' or 'one_shot'") + return "", fmt.Errorf("type must be 'recurring', 'one_shot', or 'reactive'") } db, err := t.deps.openDB() @@ -165,6 +207,9 @@ func (t *CreateScheduleTool) Execute(_ context.Context, input json.RawMessage) ( return "", fmt.Errorf("create schedule: %w", err) } + if s.Type == scheduler.Reactive { + return fmt.Sprintf("Reactive schedule created (ID: %d). Watching %s for: %s. Description: %s.", id, in.TriggerSource, in.TriggerQuery, in.Description), nil + } nextRun := t.formatNextRun(s, loc) return fmt.Sprintf("Schedule created (ID: %d). Description: %s. Next run: %s (in your timezone).", id, in.Description, nextRun), nil } @@ -227,7 +272,11 @@ func (t *ListSchedulesTool) Execute(_ context.Context, _ json.RawMessage) (strin status = "disabled" } - fmt.Fprintf(&b, "ID: %d | %s | %s | %s\n", s.ID, s.Type, status, s.Description) + fmt.Fprintf(&b, "ID: %d | %s | %s | %s | tier=%s", s.ID, s.Type, status, s.Description, s.ModelTier) + if s.MaxBudgetUSD > 0 { + fmt.Fprintf(&b, " | budget=$%.2f", s.MaxBudgetUSD) + } + fmt.Fprintln(&b) if s.Type == scheduler.Recurring { fmt.Fprintf(&b, " Cron: %s (UTC)\n", s.CronExpr) parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) @@ -239,6 +288,9 @@ func (t *ListSchedulesTool) Execute(_ context.Context, _ json.RawMessage) (strin if s.Type == scheduler.OneShot && s.ScheduledAt != nil { fmt.Fprintf(&b, " Scheduled at: %s\n", s.ScheduledAt.In(loc).Format("2006-01-02 15:04 MST")) } + if s.Type == scheduler.Reactive { + fmt.Fprintf(&b, " Source: %s | Query: %s\n", s.TriggerSource, s.TriggerQuery) + } if s.LastRunAt != nil { fmt.Fprintf(&b, " Last run: %s\n", s.LastRunAt.In(loc).Format("2006-01-02 15:04 MST")) } diff --git a/agent/tools/task_tracker.go b/agent/tools/task_tracker.go index ba9f7524..36054171 100644 --- a/agent/tools/task_tracker.go +++ b/agent/tools/task_tracker.go @@ -2,8 +2,12 @@ package tools import ( "fmt" + "log/slog" "sync" "time" + + "github.com/73ai/openbotkit/service/tasks" + "github.com/73ai/openbotkit/store" ) // TaskStatus represents the state of a delegated task. @@ -35,6 +39,7 @@ type TaskTracker struct { tasks map[string]*TaskRecord order []string // insertion order for deterministic listing maxConcurrent int + db *store.DB // nil for in-memory only } // NewTaskTracker creates a tracker with default max concurrency of 3. @@ -45,6 +50,41 @@ func NewTaskTracker() *TaskTracker { } } +// NewPersistentTaskTracker creates a tracker backed by a database. +// It migrates the schema, runs cleanup, and loads existing running tasks. +func NewPersistentTaskTracker(db *store.DB) *TaskTracker { + if err := tasks.Migrate(db); err != nil { + slog.Warn("tasks: migrate failed", "error", err) + } + tasks.Cleanup(db) + + t := &TaskTracker{ + tasks: make(map[string]*TaskRecord), + maxConcurrent: defaultMaxConcurrent, + db: db, + } + return t +} + +// OpenPersistentTaskTracker opens a DB and creates a persistent tracker. +// Falls back to in-memory on DB error. +func OpenPersistentTaskTracker(driver, dsn string) *TaskTracker { + db, err := store.Open(store.Config{Driver: driver, DSN: dsn}) + if err != nil { + slog.Warn("tasks: open db failed, using in-memory tracker", "error", err) + return NewTaskTracker() + } + return NewPersistentTaskTracker(db) +} + +// Close closes the underlying database connection if persistent. +func (t *TaskTracker) Close() error { + if t.db != nil { + return t.db.Close() + } + return nil +} + // Start registers a new running task. Returns error if at max concurrent. func (t *TaskTracker) Start(id, task string, agent AgentKind) error { t.mu.Lock() @@ -52,14 +92,23 @@ func (t *TaskTracker) Start(id, task string, agent AgentKind) error { if t.runningCountLocked() >= t.maxConcurrent { return fmt.Errorf("too many concurrent tasks (max %d)", t.maxConcurrent) } + now := time.Now() t.tasks[id] = &TaskRecord{ ID: id, Task: task, Agent: agent, Status: TaskRunning, - StartedAt: time.Now(), + StartedAt: now, } t.order = append(t.order, id) + if t.db != nil { + if err := tasks.Insert(t.db, &tasks.TaskRecord{ + ID: id, Task: task, Agent: string(agent), + Status: "running", StartedAt: now, + }); err != nil { + slog.Warn("tasks: db insert failed", "id", id, "error", err) + } + } return nil } @@ -72,6 +121,11 @@ func (t *TaskTracker) Complete(id, output string) { rec.Output = output rec.DoneAt = time.Now() } + if t.db != nil { + if err := tasks.SetCompleted(t.db, id, output); err != nil { + slog.Warn("tasks: db set completed failed", "id", id, "error", err) + } + } } // Fail marks a task as failed with an error message. @@ -83,24 +137,51 @@ func (t *TaskTracker) Fail(id, errMsg string) { rec.Error = errMsg rec.DoneAt = time.Now() } + if t.db != nil { + if err := tasks.SetFailed(t.db, id, errMsg); err != nil { + slog.Warn("tasks: db set failed", "id", id, "error", err) + } + } } -// Get returns a task record by ID. +// Get returns a task record by ID. Falls through to DB for cross-session lookup. func (t *TaskTracker) Get(id string) (*TaskRecord, bool) { t.mu.Lock() defer t.mu.Unlock() rec, ok := t.tasks[id] - if !ok { - return nil, false + if ok { + copy := *rec + return ©, true } - copy := *rec - return ©, true + if t.db != nil { + dbRec, err := tasks.Get(t.db, id) + if err != nil { + slog.Warn("tasks: db get failed", "id", id, "error", err) + return nil, false + } + if dbRec != nil { + return dbTaskToRecord(dbRec), true + } + } + return nil, false } -// List returns all tasks in insertion order. +// List returns all tasks. When DB is available, returns full cross-session view. func (t *TaskTracker) List() []*TaskRecord { t.mu.Lock() defer t.mu.Unlock() + if t.db != nil { + dbRecs, err := tasks.List(t.db) + if err != nil { + slog.Warn("tasks: db list failed", "error", err) + } else { + result := make([]*TaskRecord, 0, len(dbRecs)) + for _, r := range dbRecs { + result = append(result, dbTaskToRecord(r)) + } + return result + } + } result := make([]*TaskRecord, 0, len(t.order)) for _, id := range t.order { if rec, ok := t.tasks[id]; ok { @@ -111,6 +192,22 @@ func (t *TaskTracker) List() []*TaskRecord { return result } +func dbTaskToRecord(r *tasks.TaskRecord) *TaskRecord { + rec := &TaskRecord{ + ID: r.ID, + Task: r.Task, + Agent: AgentKind(r.Agent), + Status: TaskStatus(r.Status), + StartedAt: r.StartedAt, + Output: r.Output, + Error: r.Error, + } + if r.DoneAt != nil { + rec.DoneAt = *r.DoneAt + } + return rec +} + // RunningCount returns the number of currently running tasks. func (t *TaskTracker) RunningCount() int { t.mu.Lock() diff --git a/agent/tools/task_tracker_test.go b/agent/tools/task_tracker_test.go index 2b4da4ac..e9fcdaed 100644 --- a/agent/tools/task_tracker_test.go +++ b/agent/tools/task_tracker_test.go @@ -3,6 +3,9 @@ package tools import ( "sync" "testing" + + "github.com/73ai/openbotkit/service/tasks" + "github.com/73ai/openbotkit/store" ) func TestTaskTracker_StartAndGet(t *testing.T) { @@ -178,3 +181,87 @@ func TestTaskTracker_FailNonexistent(t *testing.T) { tr := NewTaskTracker() tr.Fail("ghost", "error") // should not panic } + +func openTestTrackerDB(t *testing.T) *store.DB { + t.Helper() + db, err := store.Open(store.SQLiteConfig(":memory:")) + if err != nil { + t.Fatalf("open db: %v", err) + } + t.Cleanup(func() { db.Close() }) + return db +} + +func TestPersistentTaskTracker_CrossSessionGet(t *testing.T) { + db := openTestTrackerDB(t) + + // Tracker 1: create and complete a task. + tr1 := NewPersistentTaskTracker(db) + tr1.Start("t1", "research", AgentClaude) + tr1.Complete("t1", "result from session 1") + + // Tracker 2: simulates a new session — same DB, empty memory. + tr2 := &TaskTracker{ + tasks: make(map[string]*TaskRecord), + maxConcurrent: defaultMaxConcurrent, + db: db, + } + + rec, ok := tr2.Get("t1") + if !ok { + t.Fatal("cross-session Get should find task in DB") + } + if rec.Status != TaskCompleted { + t.Errorf("Status = %q, want completed", rec.Status) + } + if rec.Output != "result from session 1" { + t.Errorf("Output = %q", rec.Output) + } +} + +func TestPersistentTaskTracker_CrossSessionList(t *testing.T) { + db := openTestTrackerDB(t) + + tr1 := NewPersistentTaskTracker(db) + tr1.Start("t1", "task 1", AgentClaude) + tr1.Complete("t1", "done") + tr1.Start("t2", "task 2", AgentGemini) + tr1.Fail("t2", "timeout") + + // New tracker with empty memory. + tr2 := &TaskTracker{ + tasks: make(map[string]*TaskRecord), + maxConcurrent: defaultMaxConcurrent, + db: db, + } + + list := tr2.List() + if len(list) != 2 { + t.Fatalf("got %d tasks, want 2", len(list)) + } +} + +func TestPersistentTaskTracker_DBWriteThrough(t *testing.T) { + db := openTestTrackerDB(t) + tr := NewPersistentTaskTracker(db) + + tr.Start("t1", "test", AgentClaude) + + // Verify directly in DB. + dbRec, err := tasks.Get(db, "t1") + if err != nil { + t.Fatalf("tasks.Get: %v", err) + } + if dbRec == nil { + t.Fatal("task not found in DB after Start") + } + if dbRec.Status != "running" { + t.Errorf("DB status = %q, want running", dbRec.Status) + } + + tr.Complete("t1", "output") + dbRec, _ = tasks.Get(db, "t1") + if dbRec.Status != "completed" || dbRec.Output != "output" { + t.Errorf("DB after Complete: %+v", dbRec) + } +} diff --git a/channel/telegram/session.go b/channel/telegram/session.go index 08e41d96..577f6223 100644 --- a/channel/telegram/session.go +++ b/channel/telegram/session.go @@ -75,7 +75,7 @@ func NewSessionManager(cfg *config.Config, ch *Channel, p provider.Provider, pro provider: p, providerName: providerName, model: model, - taskTracker: tools.NewTaskTracker(), + taskTracker: openTaskTracker(cfg), } if len(deps) > 0 { d := deps[0] @@ -649,6 +649,14 @@ func (sm *SessionManager) openAuditLogger() *audit.Logger { return audit.OpenDefault(config.AuditDBPath()) } +func openTaskTracker(cfg *config.Config) *tools.TaskTracker { + if err := config.EnsureSourceDir("tasks"); err != nil { + slog.Warn("tasks: ensure dir failed", "error", err) + return tools.NewTaskTracker() + } + return tools.OpenPersistentTaskTracker(cfg.Tasks.Storage.Driver, cfg.TasksDataDSN()) +} + func generateSessionID() string { var b [16]byte if _, err := rand.Read(b[:]); err != nil { diff --git a/config/config.go b/config/config.go index 565d19dc..1fb9036c 100644 --- a/config/config.go +++ b/config/config.go @@ -37,6 +37,7 @@ type Config struct { Contacts *ContactsConfig `yaml:"contacts,omitempty"` Slack *SlackConfig `yaml:"slack,omitempty"` Scheduler *SchedulerConfig `yaml:"scheduler,omitempty"` + Tasks *TasksConfig `yaml:"tasks,omitempty"` } func (c *Config) ResolvedMode() Mode { @@ -202,6 +203,10 @@ type SchedulerConfig struct { Storage StorageConfig `yaml:"storage,omitempty"` } +type TasksConfig struct { + Storage StorageConfig `yaml:"storage,omitempty"` +} + type StorageConfig struct { Driver string `yaml:"driver,omitempty"` // "sqlite" or "postgres" DSN string `yaml:"dsn,omitempty"` @@ -229,8 +234,10 @@ func (c *Config) SourceDataDSN(source string) (string, error) { return c.ContactsDataDSN(), nil case "scheduler": return c.SchedulerDataDSN(), nil + case "tasks": + return c.TasksDataDSN(), nil default: - return "", fmt.Errorf("unknown source: %q (valid: gmail, whatsapp, history, user_memory, applenotes, imessage, usage, websearch, contacts, scheduler)", source) + return "", fmt.Errorf("unknown source: %q (valid: gmail, whatsapp, history, user_memory, applenotes, imessage, usage, websearch, contacts, scheduler, tasks)", source) } } @@ -329,6 +336,11 @@ func Default() *Config { Driver: "sqlite", }, }, + Tasks: &TasksConfig{ + Storage: StorageConfig{ + Driver: "sqlite", + }, + }, } cfg.applyDefaults() return cfg @@ -404,6 +416,12 @@ func (c *Config) applyDefaults() { if c.Scheduler.Storage.Driver == "" { c.Scheduler.Storage.Driver = "sqlite" } + if c.Tasks == nil { + c.Tasks = &TasksConfig{} + } + if c.Tasks.Storage.Driver == "" { + c.Tasks.Storage.Driver = "sqlite" + } if c.Daemon == nil { c.Daemon = &DaemonConfig{} } @@ -479,6 +497,13 @@ func (c *Config) ContactsDataDSN() string { return filepath.Join(SourceDir("contacts"), "data.db") } +func (c *Config) TasksDataDSN() string { + if c.Tasks != nil && c.Tasks.Storage.DSN != "" { + return c.Tasks.Storage.DSN + } + return filepath.Join(SourceDir("tasks"), "data.db") +} + func (c *Config) SchedulerDataDSN() string { if c.Scheduler != nil && c.Scheduler.Storage.DSN != "" { return c.Scheduler.Storage.DSN diff --git a/daemon/applenotes.go b/daemon/applenotes.go index ed70d597..5a4d4f34 100644 --- a/daemon/applenotes.go +++ b/daemon/applenotes.go @@ -15,7 +15,7 @@ const appleNotesSyncInterval = 30 * time.Second // runAppleNotesSync starts a goroutine that periodically syncs Apple Notes. // Only runs on macOS. Errors are sent on the returned channel. -func runAppleNotesSync(ctx context.Context, cfg *config.Config) <-chan error { +func runAppleNotesSync(ctx context.Context, cfg *config.Config, notifier *SyncNotifier) <-chan error { errCh := make(chan error, 1) go func() { @@ -49,7 +49,7 @@ func runAppleNotesSync(ctx context.Context, cfg *config.Config) <-chan error { defer db.Close() // Run initial sync immediately. - syncAppleNotes(db) + syncAppleNotes(db, notifier) ticker := time.NewTicker(appleNotesSyncInterval) defer ticker.Stop() @@ -60,7 +60,7 @@ func runAppleNotesSync(ctx context.Context, cfg *config.Config) <-chan error { slog.Info("applenotes: stopping sync") return case <-ticker.C: - syncAppleNotes(db) + syncAppleNotes(db, notifier) } } }() @@ -68,11 +68,14 @@ func runAppleNotesSync(ctx context.Context, cfg *config.Config) <-chan error { return errCh } -func syncAppleNotes(db *store.DB) { +func syncAppleNotes(db *store.DB, notifier *SyncNotifier) { result, err := ansrc.Sync(db, ansrc.SyncOptions{}) if err != nil { slog.Error("applenotes: sync error", "error", err) return } slog.Info("applenotes: sync complete", "synced", result.Synced, "skipped", result.Skipped, "errors", result.Errors) + if notifier != nil { + notifier.Notify("applenotes") + } } diff --git a/daemon/daemon.go b/daemon/daemon.go index 970cffd7..6d8ff37f 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -66,7 +66,9 @@ func (d *Daemon) Run(ctx context.Context) error { slog.Info("starting daemon") - client, db, err := newRiverClient(ctx, d.cfg) + notifier := NewSyncNotifier() + + client, db, err := newRiverClient(ctx, d.cfg, notifier) if err != nil { return fmt.Errorf("init river: %w", err) } @@ -80,7 +82,7 @@ func (d *Daemon) Run(ctx context.Context) error { slog.Info("river job queue started") if !d.skipScheduler { - d.scheduler = NewScheduler(d.cfg, d.river, d.jobsDB) + d.scheduler = NewScheduler(d.cfg, d.river, d.jobsDB, notifier) if err := d.scheduler.Start(ctx); err != nil { slog.Error("scheduler start error", "error", err) } @@ -88,13 +90,13 @@ func (d *Daemon) Run(ctx context.Context) error { var waErrCh, anErrCh, imErrCh, ctErrCh <-chan error if !d.skipWhatsApp { - waErrCh = runWhatsAppSync(ctx, d.cfg) + waErrCh = runWhatsAppSync(ctx, d.cfg, notifier) } if !d.skipAppleNotes { - anErrCh = runAppleNotesSync(ctx, d.cfg) + anErrCh = runAppleNotesSync(ctx, d.cfg, notifier) } if !d.skipIMessage { - imErrCh = runIMessageSync(ctx, d.cfg) + imErrCh = runIMessageSync(ctx, d.cfg, notifier) } if !d.skipContacts { ctErrCh = runContactsSync(ctx, d.cfg) diff --git a/daemon/imessage.go b/daemon/imessage.go index 8026dd0f..cc2352d0 100644 --- a/daemon/imessage.go +++ b/daemon/imessage.go @@ -13,7 +13,7 @@ import ( const iMessageSyncInterval = 30 * time.Second -func runIMessageSync(ctx context.Context, cfg *config.Config) <-chan error { +func runIMessageSync(ctx context.Context, cfg *config.Config, notifier *SyncNotifier) <-chan error { errCh := make(chan error, 1) go func() { @@ -46,7 +46,7 @@ func runIMessageSync(ctx context.Context, cfg *config.Config) <-chan error { } defer db.Close() - syncIMessage(db) + syncIMessage(db, notifier) ticker := time.NewTicker(iMessageSyncInterval) defer ticker.Stop() @@ -57,7 +57,7 @@ func runIMessageSync(ctx context.Context, cfg *config.Config) <-chan error { slog.Info("imessage: stopping sync") return case <-ticker.C: - syncIMessage(db) + syncIMessage(db, notifier) } } }() @@ -65,11 +65,14 @@ func runIMessageSync(ctx context.Context, cfg *config.Config) <-chan error { return errCh } -func syncIMessage(db *store.DB) { +func syncIMessage(db *store.DB, notifier *SyncNotifier) { result, err := imsrc.Sync(db, imsrc.SyncOptions{}) if err != nil { slog.Error("imessage: sync error", "error", err) return } slog.Info("imessage: sync complete", "synced", result.Synced, "skipped", result.Skipped, "errors", result.Errors) + if notifier != nil { + notifier.Notify("imessage") + } } diff --git a/daemon/jobs/gmail_sync.go b/daemon/jobs/gmail_sync.go index 8cd413c0..230b54da 100644 --- a/daemon/jobs/gmail_sync.go +++ b/daemon/jobs/gmail_sync.go @@ -20,7 +20,8 @@ func (GmailSyncArgs) Kind() string { return "gmail_sync" } type GmailSyncWorker struct { river.WorkerDefaults[GmailSyncArgs] - Cfg *config.Config + Cfg *config.Config + OnComplete func() // called after successful sync } func (w *GmailSyncWorker) Work(ctx context.Context, job *river.Job[GmailSyncArgs]) error { @@ -55,6 +56,9 @@ func (w *GmailSyncWorker) Work(ctx context.Context, job *river.Job[GmailSyncArgs } slog.Info("gmail sync complete", "fetched", result.Fetched, "skipped", result.Skipped, "errors", result.Errors) + if w.OnComplete != nil { + w.OnComplete() + } return nil } diff --git a/daemon/jobs/scheduled_task.go b/daemon/jobs/scheduled_task.go index 16c14b55..179a7cf1 100644 --- a/daemon/jobs/scheduled_task.go +++ b/daemon/jobs/scheduled_task.go @@ -16,14 +16,17 @@ import ( "github.com/73ai/openbotkit/config" "github.com/73ai/openbotkit/provider" "github.com/73ai/openbotkit/service/scheduler" + "github.com/73ai/openbotkit/service/tasks" "github.com/73ai/openbotkit/store" ) type ScheduledTaskArgs struct { - ScheduleID int64 `json:"schedule_id"` - Task string `json:"task"` - Channel string `json:"channel"` - ChannelMeta string `json:"channel_meta"` + ScheduleID int64 `json:"schedule_id"` + Task string `json:"task"` + Channel string `json:"channel"` + ChannelMeta string `json:"channel_meta"` + ModelTier string `json:"model_tier,omitempty"` + MaxBudgetUSD float64 `json:"max_budget_usd,omitempty"` } func (ScheduledTaskArgs) Kind() string { return "scheduled_task" } @@ -41,26 +44,38 @@ type ScheduledTaskWorker struct { Cfg *config.Config MakePusher PusherFactory RunAgentFunc AgentRunner + TasksDB *store.DB // optional: for recording task results } func (w *ScheduledTaskWorker) Work(ctx context.Context, job *river.Job[ScheduledTaskArgs]) error { slog.Info("running scheduled task", "schedule_id", job.Args.ScheduleID, "attempt", job.Attempt) + taskID := fmt.Sprintf("sched-%d-%d", job.Args.ScheduleID, time.Now().UnixMilli()) + w.recordTaskStart(taskID, job.Args.Task) var meta scheduler.ChannelMeta if err := json.Unmarshal([]byte(job.Args.ChannelMeta), &meta); err != nil { return fmt.Errorf("parse channel meta: %w", err) } - runAgent := w.runAgent + var result string + var err error if w.RunAgentFunc != nil { - runAgent = w.RunAgentFunc + result, err = w.RunAgentFunc(ctx, job.Args.Task) + } else { + result, err = w.runAgentWithBudget(ctx, job.Args.Task, job.Args.ModelTier, job.Args.MaxBudgetUSD) } - result, err := runAgent(ctx, job.Args.Task) if err != nil { slog.Error("scheduled task agent failed", "schedule_id", job.Args.ScheduleID, "error", err) w.updateLastRun(job.Args.ScheduleID, err.Error()) + w.recordTaskFailed(taskID, err.Error()) - if job.Attempt >= 2 { + apiErr := provider.ClassifyError(err) + if apiErr.Kind == provider.ErrorAuth || apiErr.Kind == provider.ErrorContextWindow { + w.notifyFailure(ctx, job.Args.Channel, meta, job.Args.ScheduleID, err) + return river.JobCancel(err) + } + + if job.Attempt >= job.MaxAttempts { w.notifyFailure(ctx, job.Args.Channel, meta, job.Args.ScheduleID, err) return nil } @@ -80,17 +95,35 @@ func (w *ScheduledTaskWorker) Work(ctx context.Context, job *river.Job[Scheduled } w.updateLastRun(job.Args.ScheduleID, "") + w.recordTaskCompleted(taskID, result) w.maybeMarkCompleted(job.Args.ScheduleID) slog.Info("scheduled task complete", "schedule_id", job.Args.ScheduleID) return nil } -func (w *ScheduledTaskWorker) NextRetryAt(_ *river.Job[ScheduledTaskArgs]) time.Time { +func (w *ScheduledTaskWorker) NextRetry(job *river.Job[ScheduledTaskArgs]) time.Time { + if len(job.Errors) == 0 { + return time.Now().Add(15 * time.Minute) + } + // Auth and context-window errors are cancelled in Work() via + // river.JobCancel, so NextRetry is only called for retryable errors. + lastErr := job.Errors[len(job.Errors)-1] + apiErr := provider.ClassifyError(fmt.Errorf("%s", lastErr.Error)) + if apiErr.Kind == provider.ErrorRetryable && apiErr.StatusCode == 429 { + return time.Now().Add(30 * time.Minute) + } + if apiErr.Kind == provider.ErrorRetryable { + return time.Now().Add(10 * time.Minute) // 5xx + } return time.Now().Add(15 * time.Minute) } func (w *ScheduledTaskWorker) runAgent(ctx context.Context, task string) (string, error) { + return w.runAgentWithBudget(ctx, task, "", 0) +} + +func (w *ScheduledTaskWorker) runAgentWithBudget(ctx context.Context, task string, modelTier string, maxBudget float64) (string, error) { if w.Cfg == nil || w.Cfg.Models == nil || w.Cfg.Models.Default == "" { return "", fmt.Errorf("no LLM model configured") } @@ -100,14 +133,14 @@ func (w *ScheduledTaskWorker) runAgent(ctx context.Context, task string) (string return "", fmt.Errorf("create provider registry: %w", err) } - providerName, modelName, err := provider.ParseModelSpec(w.Cfg.Models.Default) - if err != nil { - return "", fmt.Errorf("parse model spec: %w", err) + router := provider.NewRouter(registry, w.Cfg.Models) + tier := provider.TierFast + if modelTier != "" { + tier = provider.ModelTier(modelTier) } - - p, ok := registry.Get(providerName) - if !ok { - return "", fmt.Errorf("provider %q not found", providerName) + p, modelName, err := router.Resolve(tier) + if err != nil { + return "", fmt.Errorf("resolve model tier %q: %w", tier, err) } toolReg := tools.NewScheduledTaskRegistry() @@ -127,7 +160,13 @@ func (w *ScheduledTaskWorker) runAgent(ctx context.Context, task string) (string identity := "You are a scheduled task agent. Execute the task and return a concise result.\n" blocks := tools.BuildSystemBlocks(identity, toolReg) - a := agent.New(p, modelName, toolReg, agent.WithSystemBlocks(blocks)) + opts := []agent.Option{agent.WithSystemBlocks(blocks)} + if maxBudget > 0 { + bt := agent.NewBudgetTracker(maxBudget, nil) + opts = append(opts, agent.WithUsageRecorder(bt), agent.WithBudgetChecker(bt)) + } + + a := agent.New(p, modelName, toolReg, opts...) return a.Run(ctx, task) } @@ -185,6 +224,36 @@ func (w *ScheduledTaskWorker) notifyFailure(ctx context.Context, ch string, meta } } +func (w *ScheduledTaskWorker) recordTaskStart(taskID, task string) { + if w.TasksDB == nil { + return + } + if err := tasks.Insert(w.TasksDB, &tasks.TaskRecord{ + ID: taskID, Task: task, Agent: "scheduled", + Status: "running", StartedAt: time.Now().UTC(), + }); err != nil { + slog.Warn("tasks: record start failed", "id", taskID, "error", err) + } +} + +func (w *ScheduledTaskWorker) recordTaskCompleted(taskID, output string) { + if w.TasksDB == nil { + return + } + if err := tasks.SetCompleted(w.TasksDB, taskID, output); err != nil { + slog.Warn("tasks: record completed failed", "id", taskID, "error", err) + } +} + +func (w *ScheduledTaskWorker) recordTaskFailed(taskID, errMsg string) { + if w.TasksDB == nil { + return + } + if err := tasks.SetFailed(w.TasksDB, taskID, errMsg); err != nil { + slog.Warn("tasks: record failed failed", "id", taskID, "error", err) + } +} + func openAuditLogger() *audit.Logger { return audit.OpenDefault(config.AuditDBPath()) } diff --git a/daemon/jobs/scheduled_task_worker_test.go b/daemon/jobs/scheduled_task_worker_test.go index e6ea34ca..b0993dd0 100644 --- a/daemon/jobs/scheduled_task_worker_test.go +++ b/daemon/jobs/scheduled_task_worker_test.go @@ -5,18 +5,54 @@ import ( "time" "github.com/riverqueue/river" + "github.com/riverqueue/river/rivertype" "github.com/73ai/openbotkit/config" ) -func TestScheduledTaskWorkerNextRetryAt(t *testing.T) { - w := &ScheduledTaskWorker{} - before := time.Now().Add(14 * time.Minute) - retryAt := w.NextRetryAt(&river.Job[ScheduledTaskArgs]{}) - after := time.Now().Add(16 * time.Minute) +func TestScheduledTaskWorkerNextRetry(t *testing.T) { + tests := []struct { + name string + errors []rivertype.AttemptError + minDur time.Duration + maxDur time.Duration + }{ + { + name: "no errors defaults to 15min", + errors: nil, + minDur: 14 * time.Minute, + maxDur: 16 * time.Minute, + }, + // Auth and context-window errors are cancelled in Work() via + // river.JobCancel, so NextRetry is never called for those. + { + name: "rate limit 429 delays 30min", + errors: []rivertype.AttemptError{{Error: "API error (HTTP 429): rate limited"}}, + minDur: 29 * time.Minute, + maxDur: 31 * time.Minute, + }, + { + name: "server error 500 delays 10min", + errors: []rivertype.AttemptError{{Error: "API error (HTTP 500): internal server error"}}, + minDur: 9 * time.Minute, + maxDur: 11 * time.Minute, + }, + } - if retryAt.Before(before) || retryAt.After(after) { - t.Errorf("NextRetryAt should be ~15 min from now, got %v", retryAt) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := &ScheduledTaskWorker{} + job := &river.Job[ScheduledTaskArgs]{ + JobRow: &rivertype.JobRow{Errors: tt.errors}, + } + now := time.Now() + retryAt := w.NextRetry(job) + minExpected := now.Add(tt.minDur) + maxExpected := now.Add(tt.maxDur) + if retryAt.Before(minExpected) || retryAt.After(maxExpected) { + t.Errorf("NextRetry = %v, want between %v and %v", retryAt, minExpected, maxExpected) + } + }) } } diff --git a/daemon/river.go b/daemon/river.go index 1c129e02..a84305c5 100644 --- a/daemon/river.go +++ b/daemon/river.go @@ -14,7 +14,7 @@ import ( "github.com/73ai/openbotkit/daemon/jobs" ) -func newRiverClient(ctx context.Context, cfg *config.Config) (*river.Client[*sql.Tx], *sql.DB, error) { +func newRiverClient(ctx context.Context, cfg *config.Config, notifier *SyncNotifier) (*river.Client[*sql.Tx], *sql.DB, error) { dsn := cfg.JobsDBDSN() db, err := sql.Open("sqlite", dsn) @@ -36,7 +36,10 @@ func newRiverClient(ctx context.Context, cfg *config.Config) (*river.Client[*sql } workers := river.NewWorkers() - river.AddWorker(workers, &jobs.GmailSyncWorker{Cfg: cfg}) + river.AddWorker(workers, &jobs.GmailSyncWorker{ + Cfg: cfg, + OnComplete: func() { notifier.Notify("gmail") }, + }) river.AddWorker(workers, &jobs.ReminderWorker{}) river.AddWorker(workers, &jobs.ScheduledTaskWorker{Cfg: cfg}) diff --git a/daemon/river_test.go b/daemon/river_test.go index 304a796f..6c8403d6 100644 --- a/daemon/river_test.go +++ b/daemon/river_test.go @@ -18,7 +18,7 @@ func TestNewRiverClient(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - client, db, err := newRiverClient(ctx, cfg) + client, db, err := newRiverClient(ctx, cfg, NewSyncNotifier()) if err != nil { t.Fatalf("newRiverClient failed: %v", err) } diff --git a/daemon/scheduler.go b/daemon/scheduler.go index 3fc33897..b95b48a9 100644 --- a/daemon/scheduler.go +++ b/daemon/scheduler.go @@ -6,6 +6,8 @@ import ( "encoding/json" "fmt" "log/slog" + "sort" + "strings" "sync" "time" @@ -19,21 +21,23 @@ import ( ) type Scheduler struct { - cfg *config.Config - river *river.Client[*sql.Tx] - jobsDB *sql.DB - cron *cron.Cron - mu sync.Mutex - entries map[int64]cron.EntryID - ctx context.Context + cfg *config.Config + river *river.Client[*sql.Tx] + jobsDB *sql.DB + cron *cron.Cron + mu sync.Mutex + entries map[int64]cron.EntryID + ctx context.Context + notifier *SyncNotifier } -func NewScheduler(cfg *config.Config, riverClient *river.Client[*sql.Tx], jobsDB *sql.DB) *Scheduler { +func NewScheduler(cfg *config.Config, riverClient *river.Client[*sql.Tx], jobsDB *sql.DB, notifier *SyncNotifier) *Scheduler { return &Scheduler{ - cfg: cfg, - river: riverClient, - jobsDB: jobsDB, - entries: make(map[int64]cron.EntryID), + cfg: cfg, + river: riverClient, + jobsDB: jobsDB, + entries: make(map[int64]cron.EntryID), + notifier: notifier, } } @@ -58,6 +62,9 @@ func (s *Scheduler) Start(ctx context.Context) error { go s.reloadLoop(ctx) go s.oneShotLoop(ctx) + if s.notifier != nil { + go s.reactiveCheckLoop(ctx) + } slog.Info("scheduler started") return nil @@ -151,10 +158,12 @@ func (s *Scheduler) loadSchedules() error { func (s *Scheduler) addCronEntry(sched scheduler.Schedule) (cron.EntryID, error) { metaJSON, _ := json.Marshal(sched.ChannelMeta) args := jobs.ScheduledTaskArgs{ - ScheduleID: sched.ID, - Task: sched.Task, - Channel: sched.Channel, - ChannelMeta: string(metaJSON), + ScheduleID: sched.ID, + Task: sched.Task, + Channel: sched.Channel, + ChannelMeta: string(metaJSON), + ModelTier: sched.ModelTier, + MaxBudgetUSD: sched.MaxBudgetUSD, } return s.cron.AddFunc(sched.CronExpr, func() { @@ -164,7 +173,7 @@ func (s *Scheduler) addCronEntry(sched scheduler.Schedule) (cron.EntryID, error) return } _, err = s.river.InsertTx(s.ctx, tx, args, &river.InsertOpts{ - MaxAttempts: 2, + MaxAttempts: 3, }) if err != nil { tx.Rollback() @@ -192,10 +201,12 @@ func (s *Scheduler) pollOneShot(ctx context.Context) error { for _, sched := range due { metaJSON, _ := json.Marshal(sched.ChannelMeta) args := jobs.ScheduledTaskArgs{ - ScheduleID: sched.ID, - Task: sched.Task, - Channel: sched.Channel, - ChannelMeta: string(metaJSON), + ScheduleID: sched.ID, + Task: sched.Task, + Channel: sched.Channel, + ChannelMeta: string(metaJSON), + ModelTier: sched.ModelTier, + MaxBudgetUSD: sched.MaxBudgetUSD, } tx, err := s.jobsDB.Begin() @@ -204,7 +215,7 @@ func (s *Scheduler) pollOneShot(ctx context.Context) error { continue } _, err = s.river.InsertTx(ctx, tx, args, &river.InsertOpts{ - MaxAttempts: 2, + MaxAttempts: 3, }) if err != nil { tx.Rollback() @@ -226,6 +237,136 @@ func (s *Scheduler) pollOneShot(ctx context.Context) error { return nil } +func (s *Scheduler) reactiveCheckLoop(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case sig := <-s.notifier.C(): + if err := s.checkReactiveTriggers(ctx, sig.Source); err != nil { + slog.Error("scheduler: reactive check failed", "source", sig.Source, "error", err) + } + } + } +} + +func (s *Scheduler) checkReactiveTriggers(ctx context.Context, source string) error { + schedDB, err := s.openDB() + if err != nil { + return fmt.Errorf("open scheduler db: %w", err) + } + defer schedDB.Close() + + schedules, err := scheduler.ListEnabledReactive(schedDB, source) + if err != nil { + return fmt.Errorf("list reactive: %w", err) + } + if len(schedules) == 0 { + return nil + } + + dsn, err := s.cfg.SourceDataDSN(source) + if err != nil { + return fmt.Errorf("source dsn: %w", err) + } + sourceDB, err := store.Open(store.SQLiteConfig(dsn)) + if err != nil { + return fmt.Errorf("open source db %q: %w", source, err) + } + defer sourceDB.Close() + + return s.checkReactiveTriggersWithDB(ctx, schedDB, sourceDB, schedules) +} + +func (s *Scheduler) checkReactiveTriggersWithDB(ctx context.Context, schedDB *store.DB, sourceDB *store.DB, schedules []scheduler.Schedule) error { + for _, sched := range schedules { + match, err := scheduler.CheckTrigger(sourceDB, sched.TriggerSource, sched.TriggerQuery, sched.LastTriggerID) + if err != nil { + slog.Error("scheduler: trigger check failed", "id", sched.ID, "error", err) + continue + } + if match == nil { + continue + } + + // Build augmented task with matched data summary. + task := sched.Task + "\n\nTriggered by " + fmt.Sprintf("%d", len(match.Rows)) + " new matching row(s) from " + sched.TriggerSource + ":\n" + for i, row := range match.Rows { + if i >= 5 { + task += fmt.Sprintf("... and %d more\n", len(match.Rows)-5) + break + } + task += formatRow(row) + } + + metaJSON, _ := json.Marshal(sched.ChannelMeta) + args := jobs.ScheduledTaskArgs{ + ScheduleID: sched.ID, + Task: task, + Channel: sched.Channel, + ChannelMeta: string(metaJSON), + ModelTier: sched.ModelTier, + MaxBudgetUSD: sched.MaxBudgetUSD, + } + + tx, err := s.jobsDB.Begin() + if err != nil { + slog.Error("scheduler: begin tx for reactive", "error", err) + continue + } + _, err = s.river.InsertTx(ctx, tx, args, &river.InsertOpts{ + MaxAttempts: 3, + }) + if err != nil { + tx.Rollback() + slog.Error("scheduler: insert reactive job", "schedule_id", sched.ID, "error", err) + continue + } + if err := tx.Commit(); err != nil { + slog.Error("scheduler: commit reactive tx", "error", err) + continue + } + + if err := scheduler.UpdateLastTriggerID(schedDB, sched.ID, match.MaxID); err != nil { + slog.Error("scheduler: update watermark", "id", sched.ID, "error", err) + } + if err := scheduler.UpdateLastRun(schedDB, sched.ID, time.Now().UTC(), ""); err != nil { + slog.Error("scheduler: update last run", "id", sched.ID, "error", err) + } + + slog.Info("scheduler: enqueued reactive task", "schedule_id", sched.ID, "matched_rows", len(match.Rows), "watermark", match.MaxID) + } + return nil +} + +// CheckReactiveTriggersForTest exposes reactive trigger checking for tests. +func (s *Scheduler) CheckReactiveTriggersForTest(ctx context.Context, source string, sourceDB *store.DB) error { + schedDB, err := s.openDB() + if err != nil { + return fmt.Errorf("open scheduler db: %w", err) + } + defer schedDB.Close() + + schedules, err := scheduler.ListEnabledReactive(schedDB, source) + if err != nil { + return fmt.Errorf("list reactive: %w", err) + } + return s.checkReactiveTriggersWithDB(ctx, schedDB, sourceDB, schedules) +} + +func formatRow(row map[string]string) string { + keys := make([]string, 0, len(row)) + for k := range row { + keys = append(keys, k) + } + sort.Strings(keys) + parts := make([]string, 0, len(keys)) + for _, k := range keys { + parts = append(parts, k+": "+row[k]) + } + return " " + strings.Join(parts, " | ") + "\n" +} + func (s *Scheduler) isValidFrequency(cronExpr string) bool { parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) sched, err := parser.Parse(cronExpr) diff --git a/daemon/sync_signal.go b/daemon/sync_signal.go new file mode 100644 index 00000000..2574cea4 --- /dev/null +++ b/daemon/sync_signal.go @@ -0,0 +1,24 @@ +package daemon + +type SyncSignal struct { + Source string +} + +type SyncNotifier struct { + ch chan SyncSignal +} + +func NewSyncNotifier() *SyncNotifier { + return &SyncNotifier{ch: make(chan SyncSignal, 16)} +} + +func (n *SyncNotifier) Notify(source string) { + select { + case n.ch <- SyncSignal{Source: source}: + default: + } +} + +func (n *SyncNotifier) C() <-chan SyncSignal { + return n.ch +} diff --git a/daemon/sync_signal_test.go b/daemon/sync_signal_test.go new file mode 100644 index 00000000..a13d3f2c --- /dev/null +++ b/daemon/sync_signal_test.go @@ -0,0 +1,24 @@ +package daemon + +import "testing" + +func TestSyncNotifier_SendReceive(t *testing.T) { + n := NewSyncNotifier() + n.Notify("gmail") + sig := <-n.C() + if sig.Source != "gmail" { + t.Errorf("got source %q, want gmail", sig.Source) + } +} + +func TestSyncNotifier_NonBlockingWhenFull(t *testing.T) { + n := NewSyncNotifier() + // Fill the buffer (cap 16). + for i := 0; i < 20; i++ { + n.Notify("test") + } + // Should not panic or block. + if len(n.ch) != 16 { + t.Errorf("expected buffer full at 16, got %d", len(n.ch)) + } +} diff --git a/daemon/whatsapp.go b/daemon/whatsapp.go index d72bd40b..8aa97911 100644 --- a/daemon/whatsapp.go +++ b/daemon/whatsapp.go @@ -3,6 +3,7 @@ package daemon import ( "context" "log/slog" + "time" "github.com/73ai/openbotkit/config" wasrc "github.com/73ai/openbotkit/source/whatsapp" @@ -11,7 +12,7 @@ import ( // runWhatsAppSync starts a WhatsApp sync goroutine that runs until ctx is cancelled. // Errors are sent on the returned channel (non-blocking). -func runWhatsAppSync(ctx context.Context, cfg *config.Config) <-chan error { +func runWhatsAppSync(ctx context.Context, cfg *config.Config, notifier *SyncNotifier) <-chan error { errCh := make(chan error, 1) go func() { @@ -45,6 +46,23 @@ func runWhatsAppSync(ctx context.Context, cfg *config.Config) <-chan error { } defer db.Close() + // WhatsApp uses streaming sync (Follow: true), so notify periodically + // while messages arrive, matching the cadence of other sync sources. + if notifier != nil { + go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + notifier.Notify("whatsapp") + } + } + }() + } + slog.Info("whatsapp: starting sync") result, err := wasrc.Sync(ctx, client, db, wasrc.SyncOptions{ Follow: true, diff --git a/internal/cli/chat.go b/internal/cli/chat.go index d4e43503..8d21fef1 100644 --- a/internal/cli/chat.go +++ b/internal/cli/chat.go @@ -102,7 +102,9 @@ var chatCmd = &cobra.Command{ })) // Register delegate_task if external AI CLIs are available. - registerDelegateTool(toolReg, ch) + tracker := openTaskTracker(cfg) + defer tracker.Close() + registerDelegateTool(toolReg, ch, tracker) // Register Slack tools if configured. registerSlackTools(cfg, toolReg, ch) @@ -244,13 +246,20 @@ func registerSlackTools(cfg *config.Config, reg *tools.Registry, ch *clicli.Chan reg.Register(tools.NewSlackReactTool(deps)) } -func registerDelegateTool(reg *tools.Registry, ch *clicli.Channel) { +func openTaskTracker(cfg *config.Config) *tools.TaskTracker { + if err := config.EnsureSourceDir("tasks"); err != nil { + slog.Warn("tasks: ensure dir failed", "error", err) + return tools.NewTaskTracker() + } + return tools.OpenPersistentTaskTracker(cfg.Tasks.Storage.Driver, cfg.TasksDataDSN()) +} + +func registerDelegateTool(reg *tools.Registry, ch *clicli.Channel, tracker *tools.TaskTracker) { agents := tools.DetectAgents() if len(agents) == 0 { return } inter := NewCLIInteractor(ch) - tracker := tools.NewTaskTracker() reg.Register(tools.NewDelegateTaskTool(tools.DelegateTaskConfig{ Interactor: inter, Agents: agents, diff --git a/internal/cli/usage.go b/internal/cli/usage.go index 5a1d8c9c..95439f8e 100644 --- a/internal/cli/usage.go +++ b/internal/cli/usage.go @@ -3,12 +3,12 @@ package cli import ( "encoding/json" "fmt" - "math" "os" "text/tabwriter" "time" "github.com/73ai/openbotkit/config" + "github.com/73ai/openbotkit/provider" usagesrc "github.com/73ai/openbotkit/service/usage" "github.com/73ai/openbotkit/store" "github.com/spf13/cobra" @@ -152,54 +152,11 @@ func formatTokens(n int64) string { return fmt.Sprintf("%d", n) } -// Pricing per million tokens (input, output, cache_read). -// Cache writes are charged at input rate. -var modelPricing = map[string][3]float64{ - "claude-sonnet-4-6": {3.0, 15.0, 0.30}, - "claude-sonnet-4-20250514": {3.0, 15.0, 0.30}, - "claude-haiku-4-5": {0.80, 4.0, 0.08}, - "claude-opus-4-6": {15.0, 75.0, 1.50}, - "gpt-4o": {2.50, 10.0, 1.25}, - "gpt-4o-mini": {0.15, 0.60, 0.075}, - "gpt-4.1": {2.00, 8.00, 0.50}, - "gpt-4.1-mini": {0.40, 1.60, 0.10}, - "gpt-4.1-nano": {0.10, 0.40, 0.025}, - "gemini-2.5-pro": {1.25, 10.0, 0.3125}, - "gemini-2.5-flash": {0.15, 0.60, 0.0375}, -} - func estimateCost(r usagesrc.AggregatedUsage) float64 { - pricing, ok := modelPricing[r.Model] - if !ok { - // Try prefix matching for versioned model names. - // Use longest match to avoid "gpt-4o" matching "gpt-4o-mini-*". - bestLen := 0 - for prefix, p := range modelPricing { - if len(prefix) > bestLen && len(r.Model) >= len(prefix) && r.Model[:len(prefix)] == prefix { - pricing = p - bestLen = len(prefix) - ok = true - } - } - } - if !ok { - return 0 - } - - inputRate := pricing[0] / 1_000_000 - outputRate := pricing[1] / 1_000_000 - cacheReadRate := pricing[2] / 1_000_000 - - // Non-cached input tokens = total input - cache_read. - nonCachedInput := r.InputTokens - r.CacheReadTokens - if nonCachedInput < 0 { - nonCachedInput = 0 - } - - cost := float64(nonCachedInput)*inputRate + - float64(r.OutputTokens)*outputRate + - float64(r.CacheReadTokens)*cacheReadRate + - float64(r.CacheWriteTokens)*inputRate // cache writes charged at input rate - - return math.Round(cost*100) / 100 + return provider.EstimateCost(r.Model, provider.Usage{ + InputTokens: int(r.InputTokens), + OutputTokens: int(r.OutputTokens), + CacheReadTokens: int(r.CacheReadTokens), + CacheWriteTokens: int(r.CacheWriteTokens), + }) } diff --git a/provider/pricing.go b/provider/pricing.go new file mode 100644 index 00000000..0b60a1d6 --- /dev/null +++ b/provider/pricing.go @@ -0,0 +1,54 @@ +package provider + +import "math" + +// ModelPricing maps model names to per-million-token rates: +// [input, output, cache_read]. Cache writes are charged at input rate. +var ModelPricing = map[string][3]float64{ + "claude-sonnet-4-6": {3.0, 15.0, 0.30}, + "claude-sonnet-4-20250514": {3.0, 15.0, 0.30}, + "claude-haiku-4-5": {0.80, 4.0, 0.08}, + "claude-opus-4-6": {15.0, 75.0, 1.50}, + "gpt-4o": {2.50, 10.0, 1.25}, + "gpt-4o-mini": {0.15, 0.60, 0.075}, + "gpt-4.1": {2.00, 8.00, 0.50}, + "gpt-4.1-mini": {0.40, 1.60, 0.10}, + "gpt-4.1-nano": {0.10, 0.40, 0.025}, + "gemini-2.5-pro": {1.25, 10.0, 0.3125}, + "gemini-2.5-flash": {0.15, 0.60, 0.0375}, +} + +// EstimateCost calculates the estimated cost for a model usage. +// Returns 0 for unknown models. Uses prefix matching for versioned names. +func EstimateCost(model string, usage Usage) float64 { + pricing, ok := ModelPricing[model] + if !ok { + bestLen := 0 + for prefix, p := range ModelPricing { + if len(prefix) > bestLen && len(model) >= len(prefix) && model[:len(prefix)] == prefix { + pricing = p + bestLen = len(prefix) + ok = true + } + } + } + if !ok { + return 0 + } + + inputRate := pricing[0] / 1_000_000 + outputRate := pricing[1] / 1_000_000 + cacheReadRate := pricing[2] / 1_000_000 + + nonCachedInput := int64(usage.InputTokens) - int64(usage.CacheReadTokens) + if nonCachedInput < 0 { + nonCachedInput = 0 + } + + cost := float64(nonCachedInput)*inputRate + + float64(usage.OutputTokens)*outputRate + + float64(usage.CacheReadTokens)*cacheReadRate + + float64(usage.CacheWriteTokens)*inputRate + + return math.Round(cost*100) / 100 +} diff --git a/provider/pricing_test.go b/provider/pricing_test.go new file mode 100644 index 00000000..819b7538 --- /dev/null +++ b/provider/pricing_test.go @@ -0,0 +1,44 @@ +package provider + +import "testing" + +func TestEstimateCost_KnownModel(t *testing.T) { + usage := Usage{InputTokens: 1_000_000, OutputTokens: 1_000_000} + cost := EstimateCost("claude-sonnet-4-6", usage) + // 3.0 input + 15.0 output = 18.0 + if cost != 18.0 { + t.Errorf("cost = %f, want 18.0", cost) + } +} + +func TestEstimateCost_UnknownModel(t *testing.T) { + usage := Usage{InputTokens: 1000, OutputTokens: 1000} + cost := EstimateCost("unknown-model-v1", usage) + if cost != 0 { + t.Errorf("cost = %f, want 0", cost) + } +} + +func TestEstimateCost_PrefixMatching(t *testing.T) { + usage := Usage{InputTokens: 1_000_000, OutputTokens: 1_000_000} + cost := EstimateCost("claude-sonnet-4-6-20260101", usage) + if cost != 18.0 { + t.Errorf("prefix match cost = %f, want 18.0", cost) + } +} + +func TestEstimateCost_WithCache(t *testing.T) { + usage := Usage{ + InputTokens: 1_000_000, + OutputTokens: 500_000, + CacheReadTokens: 200_000, + } + cost := EstimateCost("claude-sonnet-4-6", usage) + // non-cached input: 800k * 3.0/M = 2.4 + // output: 500k * 15.0/M = 7.5 + // cache read: 200k * 0.30/M = 0.06 + // total = 9.96 + if cost != 9.96 { + t.Errorf("cost = %f, want 9.96", cost) + } +} diff --git a/provider/router.go b/provider/router.go index 690bd574..41654a6a 100644 --- a/provider/router.go +++ b/provider/router.go @@ -30,7 +30,7 @@ func NewRouter(registry *Registry, models *config.ModelsConfig) *Router { // Chat routes a request to the appropriate provider and model based on tier. func (r *Router) Chat(ctx context.Context, tier ModelTier, req ChatRequest) (*ChatResponse, error) { - p, model, err := r.resolve(tier) + p, model, err := r.Resolve(tier) if err != nil { return nil, err } @@ -40,7 +40,7 @@ func (r *Router) Chat(ctx context.Context, tier ModelTier, req ChatRequest) (*Ch // StreamChat routes a streaming request to the appropriate provider. func (r *Router) StreamChat(ctx context.Context, tier ModelTier, req ChatRequest) (<-chan StreamEvent, error) { - p, model, err := r.resolve(tier) + p, model, err := r.Resolve(tier) if err != nil { return nil, err } @@ -48,9 +48,9 @@ func (r *Router) StreamChat(ctx context.Context, tier ModelTier, req ChatRequest return p.StreamChat(ctx, req) } -// resolve returns the provider and model for the given tier. +// Resolve returns the provider and model for the given tier. // Cascade order: nano → fast → default, complex → default. -func (r *Router) resolve(tier ModelTier) (Provider, string, error) { +func (r *Router) Resolve(tier ModelTier) (Provider, string, error) { spec := r.specForTier(tier) if spec == "" && tier == TierNano { spec = r.specForTier(TierFast) diff --git a/service/scheduler/schema.go b/service/scheduler/schema.go index 842295b7..e37e6514 100644 --- a/service/scheduler/schema.go +++ b/service/scheduler/schema.go @@ -49,6 +49,19 @@ func Migrate(db *store.DB) error { if db.IsPostgres() { schema = schemaPostgres } - _, err := db.Exec(schema) - return err + if _, err := db.Exec(schema); err != nil { + return err + } + // Best-effort ALTER TABLE for existing databases. + // SQLite: duplicate column errors are safe to ignore. + for _, stmt := range []string{ + "ALTER TABLE schedules ADD COLUMN model_tier TEXT DEFAULT 'fast'", + "ALTER TABLE schedules ADD COLUMN max_budget_usd REAL DEFAULT 0", + "ALTER TABLE schedules ADD COLUMN trigger_source TEXT", + "ALTER TABLE schedules ADD COLUMN trigger_query TEXT", + "ALTER TABLE schedules ADD COLUMN last_trigger_id INTEGER DEFAULT 0", + } { + db.Exec(stmt) // ignore "duplicate column" errors + } + return nil } diff --git a/service/scheduler/store.go b/service/scheduler/store.go index e2148dc8..74228151 100644 --- a/service/scheduler/store.go +++ b/service/scheduler/store.go @@ -23,9 +23,9 @@ func Create(db *store.DB, s *Schedule) (int64, error) { } res, err := db.Exec( - db.Rebind(`INSERT INTO schedules (type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`), - string(s.Type), s.CronExpr, scheduledAt, s.Task, s.Channel, string(metaJSON), s.Timezone, s.Description, 1, + db.Rebind(`INSERT INTO schedules (type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled, model_tier, max_budget_usd, trigger_source, trigger_query, last_trigger_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`), + string(s.Type), s.CronExpr, scheduledAt, s.Task, s.Channel, string(metaJSON), s.Timezone, s.Description, 1, s.ModelTier, s.MaxBudgetUSD, s.TriggerSource, s.TriggerQuery, s.LastTriggerID, ) if err != nil { return 0, fmt.Errorf("insert schedule: %w", err) @@ -35,7 +35,7 @@ func Create(db *store.DB, s *Schedule) (int64, error) { func Get(db *store.DB, id int64) (*Schedule, error) { row := db.QueryRow( - db.Rebind(`SELECT id, type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled, last_run_at, last_error, created_at, completed_at + db.Rebind(`SELECT id, type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled, last_run_at, last_error, created_at, completed_at, model_tier, max_budget_usd, trigger_source, trigger_query, last_trigger_id FROM schedules WHERE id = ?`), id, ) @@ -44,7 +44,7 @@ func Get(db *store.DB, id int64) (*Schedule, error) { func List(db *store.DB) ([]Schedule, error) { rows, err := db.Query( - `SELECT id, type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled, last_run_at, last_error, created_at, completed_at + `SELECT id, type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled, last_run_at, last_error, created_at, completed_at, model_tier, max_budget_usd, trigger_source, trigger_query, last_trigger_id FROM schedules ORDER BY created_at`) if err != nil { return nil, fmt.Errorf("list schedules: %w", err) @@ -55,7 +55,7 @@ func List(db *store.DB) ([]Schedule, error) { func ListEnabled(db *store.DB) ([]Schedule, error) { rows, err := db.Query( - db.Rebind(`SELECT id, type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled, last_run_at, last_error, created_at, completed_at + db.Rebind(`SELECT id, type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled, last_run_at, last_error, created_at, completed_at, model_tier, max_budget_usd, trigger_source, trigger_query, last_trigger_id FROM schedules WHERE enabled = ? ORDER BY created_at`), 1) if err != nil { return nil, fmt.Errorf("list enabled schedules: %w", err) @@ -66,7 +66,7 @@ func ListEnabled(db *store.DB) ([]Schedule, error) { func ListDueOneShot(db *store.DB, now time.Time) ([]Schedule, error) { rows, err := db.Query( - db.Rebind(`SELECT id, type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled, last_run_at, last_error, created_at, completed_at + db.Rebind(`SELECT id, type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled, last_run_at, last_error, created_at, completed_at, model_tier, max_budget_usd, trigger_source, trigger_query, last_trigger_id FROM schedules WHERE type = 'one_shot' AND scheduled_at <= ? AND completed_at IS NULL AND enabled = ? ORDER BY scheduled_at`), now.UTC().Format(timeFormat), 1, ) @@ -134,11 +134,15 @@ func scanSchedule(row scannable) (*Schedule, error) { var s Schedule var typ, cronExpr, scheduledAt, channel, metaJSON, tz sql.NullString var desc, lastRunAt, lastError, createdAt, completedAt sql.NullString + var modelTier, triggerSource, triggerQuery sql.NullString + var maxBudget sql.NullFloat64 + var lastTriggerID sql.NullInt64 var enabled any err := row.Scan( &s.ID, &typ, &cronExpr, &scheduledAt, &s.Task, &channel, &metaJSON, &tz, &desc, &enabled, &lastRunAt, &lastError, &createdAt, &completedAt, + &modelTier, &maxBudget, &triggerSource, &triggerQuery, &lastTriggerID, ) if err != nil { if err == sql.ErrNoRows { @@ -160,6 +164,15 @@ func scanSchedule(row scannable) (*Schedule, error) { s.Enabled = v } s.LastError = lastError.String + s.ModelTier = modelTier.String + if maxBudget.Valid { + s.MaxBudgetUSD = maxBudget.Float64 + } + s.TriggerSource = triggerSource.String + s.TriggerQuery = triggerQuery.String + if lastTriggerID.Valid { + s.LastTriggerID = lastTriggerID.Int64 + } if scheduledAt.Valid { t, err := parseTime(scheduledAt.String) @@ -211,6 +224,30 @@ func scanSchedules(rows *sql.Rows) ([]Schedule, error) { return result, rows.Err() } +func ListEnabledReactive(db *store.DB, source string) ([]Schedule, error) { + rows, err := db.Query( + db.Rebind(`SELECT id, type, cron_expr, scheduled_at, task, channel, channel_meta, timezone, description, enabled, last_run_at, last_error, created_at, completed_at, model_tier, max_budget_usd, trigger_source, trigger_query, last_trigger_id + FROM schedules WHERE type = 'reactive' AND trigger_source = ? AND enabled = ? ORDER BY created_at`), + source, 1, + ) + if err != nil { + return nil, fmt.Errorf("list reactive schedules: %w", err) + } + defer rows.Close() + return scanSchedules(rows) +} + +func UpdateLastTriggerID(db *store.DB, id int64, lastTriggerID int64) error { + _, err := db.Exec( + db.Rebind("UPDATE schedules SET last_trigger_id = ? WHERE id = ?"), + lastTriggerID, id, + ) + if err != nil { + return fmt.Errorf("update last_trigger_id: %w", err) + } + return nil +} + func parseTime(s string) (*time.Time, error) { for _, f := range []string{ timeFormat, diff --git a/service/scheduler/store_test.go b/service/scheduler/store_test.go index b40ec91f..df90fc1f 100644 --- a/service/scheduler/store_test.go +++ b/service/scheduler/store_test.go @@ -139,6 +139,79 @@ func TestListDueOneShot(t *testing.T) { } } +func TestListEnabledReactive(t *testing.T) { + db := openTestDB(t) + + // Create a reactive schedule. + _, err := Create(db, &Schedule{ + Type: Reactive, + TriggerSource: "gmail", + TriggerQuery: "from_addr LIKE '%@acme.com%'", + Task: "summarize", + Channel: "test", + ChannelMeta: ChannelMeta{BotToken: "tok", OwnerID: 1}, + Timezone: "UTC", + }) + if err != nil { + t.Fatalf("create reactive: %v", err) + } + + // Create a non-matching reactive (different source). + _, err = Create(db, &Schedule{ + Type: Reactive, + TriggerSource: "whatsapp", + TriggerQuery: "sender_name = 'Bob'", + Task: "reply", + Channel: "test", + ChannelMeta: ChannelMeta{BotToken: "tok", OwnerID: 1}, + Timezone: "UTC", + }) + if err != nil { + t.Fatalf("create reactive whatsapp: %v", err) + } + + // Only gmail reactive should be returned. + list, err := ListEnabledReactive(db, "gmail") + if err != nil { + t.Fatalf("list reactive: %v", err) + } + if len(list) != 1 { + t.Fatalf("got %d, want 1", len(list)) + } + if list[0].TriggerSource != "gmail" { + t.Errorf("source: got %q, want gmail", list[0].TriggerSource) + } +} + +func TestUpdateLastTriggerID(t *testing.T) { + db := openTestDB(t) + + id, err := Create(db, &Schedule{ + Type: Reactive, + TriggerSource: "gmail", + TriggerQuery: "from_addr = 'x'", + Task: "test", + Channel: "test", + ChannelMeta: ChannelMeta{BotToken: "tok", OwnerID: 1}, + Timezone: "UTC", + }) + if err != nil { + t.Fatalf("create: %v", err) + } + + if err := UpdateLastTriggerID(db, id, 42); err != nil { + t.Fatalf("update: %v", err) + } + + got, err := Get(db, id) + if err != nil { + t.Fatalf("get: %v", err) + } + if got.LastTriggerID != 42 { + t.Errorf("last_trigger_id: got %d, want 42", got.LastTriggerID) + } +} + func TestMarkCompleted(t *testing.T) { db := openTestDB(t) diff --git a/service/scheduler/trigger.go b/service/scheduler/trigger.go new file mode 100644 index 00000000..98963174 --- /dev/null +++ b/service/scheduler/trigger.go @@ -0,0 +1,163 @@ +package scheduler + +import ( + "fmt" + "regexp" + "strings" + + "github.com/73ai/openbotkit/store" +) + +var triggerTemplates = map[string]string{ + "gmail": `SELECT id, subject, from_addr, date FROM emails WHERE id > ? AND (%s) ORDER BY id LIMIT 50`, + "whatsapp": `SELECT id, message_id, sender_name, text, timestamp FROM whatsapp_messages WHERE id > ? AND (%s) ORDER BY id LIMIT 50`, + "imessage": `SELECT id, guid, sender_id, text, date_utc FROM imessage_messages WHERE id > ? AND (%s) ORDER BY id LIMIT 50`, + "applenotes": `SELECT id, title, folder, modified_at FROM applenotes_notes WHERE id > ? AND (%s) ORDER BY id LIMIT 50`, +} + +func ValidateTriggerSource(source string) error { + if _, ok := triggerTemplates[source]; !ok { + return fmt.Errorf("unknown trigger source %q; supported: gmail, whatsapp, imessage, applenotes", source) + } + return nil +} + +// allowedColumns defines which column names are permitted per trigger source. +var allowedColumns = map[string]map[string]bool{ + "gmail": {"SUBJECT": true, "FROM_ADDR": true, "DATE": true}, + "whatsapp": {"MESSAGE_ID": true, "SENDER_NAME": true, "TEXT": true, "TIMESTAMP": true}, + "imessage": {"GUID": true, "SENDER_ID": true, "TEXT": true, "DATE_UTC": true}, + "applenotes": {"TITLE": true, "FOLDER": true, "MODIFIED_AT": true}, +} + +// allowedKeywords are SQL keywords permitted in trigger WHERE clauses. +var allowedKeywords = map[string]bool{ + "AND": true, "OR": true, "NOT": true, + "LIKE": true, "GLOB": true, "ESCAPE": true, + "IN": true, "BETWEEN": true, + "IS": true, "NULL": true, +} + +var identPattern = regexp.MustCompile(`[a-zA-Z_][a-zA-Z0-9_]*`) + +// ValidateTriggerQuery validates that a trigger WHERE clause only uses allowed +// columns and keywords. Uses an allowlist approach instead of a denylist +// to prevent SQL injection via UNION, subqueries, functions, etc. +func ValidateTriggerQuery(source, query string) error { + if strings.TrimSpace(query) == "" { + return fmt.Errorf("trigger query must not be empty") + } + if strings.Contains(query, "--") { + return fmt.Errorf("trigger query must not contain SQL comments") + } + if strings.Contains(query, ";") { + return fmt.Errorf("trigger query must not contain semicolons") + } + if strings.Count(query, "(") != strings.Count(query, ")") { + return fmt.Errorf("trigger query has unbalanced parentheses") + } + + cols := allowedColumns[source] + if cols == nil { + return fmt.Errorf("unknown trigger source %q", source) + } + + // Strip string literals to avoid false positives on keywords inside strings. + stripped := stripStringLiterals(query) + + // Every identifier must be an allowed column or SQL keyword. + idents := identPattern.FindAllString(stripped, -1) + for _, ident := range idents { + upper := strings.ToUpper(ident) + if cols[upper] || allowedKeywords[upper] { + continue + } + return fmt.Errorf("trigger query contains disallowed identifier %q", ident) + } + + return nil +} + +// stripStringLiterals removes content between single quotes (SQL string literals). +func stripStringLiterals(s string) string { + var b strings.Builder + inStr := false + for i := 0; i < len(s); i++ { + if s[i] == '\'' { + if inStr && i+1 < len(s) && s[i+1] == '\'' { + i++ // skip escaped quote ('') + continue + } + inStr = !inStr + continue + } + if !inStr { + b.WriteByte(s[i]) + } + } + return b.String() +} + +func BuildTriggerQuery(source, whereClause string, lastTriggerID int64) (string, []any, error) { + tmpl, ok := triggerTemplates[source] + if !ok { + return "", nil, fmt.Errorf("unknown trigger source %q", source) + } + q := fmt.Sprintf(tmpl, whereClause) + return q, []any{lastTriggerID}, nil +} + +type TriggerMatch struct { + MaxID int64 + Rows []map[string]string +} + +func CheckTrigger(db *store.DB, source, whereClause string, lastTriggerID int64) (*TriggerMatch, error) { + query, args, err := BuildTriggerQuery(source, whereClause, lastTriggerID) + if err != nil { + return nil, err + } + + rows, err := db.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("trigger query: %w", err) + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + return nil, fmt.Errorf("trigger columns: %w", err) + } + + var match TriggerMatch + for rows.Next() { + vals := make([]any, len(cols)) + ptrs := make([]any, len(cols)) + for i := range vals { + ptrs[i] = &vals[i] + } + if err := rows.Scan(ptrs...); err != nil { + return nil, fmt.Errorf("trigger scan: %w", err) + } + row := make(map[string]string, len(cols)) + for i, col := range cols { + row[col] = fmt.Sprintf("%v", vals[i]) + } + if id, ok := row["id"]; ok { + var n int64 + fmt.Sscanf(id, "%d", &n) + if n > match.MaxID { + match.MaxID = n + } + } + match.Rows = append(match.Rows, row) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("trigger rows: %w", err) + } + + if len(match.Rows) == 0 { + return nil, nil + } + return &match, nil +} diff --git a/service/scheduler/trigger_test.go b/service/scheduler/trigger_test.go new file mode 100644 index 00000000..d888f747 --- /dev/null +++ b/service/scheduler/trigger_test.go @@ -0,0 +1,124 @@ +package scheduler + +import "testing" + +func TestValidateTriggerSource(t *testing.T) { + for _, src := range []string{"gmail", "whatsapp", "imessage", "applenotes"} { + if err := ValidateTriggerSource(src); err != nil { + t.Errorf("ValidateTriggerSource(%q) = %v, want nil", src, err) + } + } + if err := ValidateTriggerSource("unknown"); err == nil { + t.Error("expected error for unknown source") + } +} + +func TestValidateTriggerQuery(t *testing.T) { + tests := []struct { + name string + source string + query string + wantErr bool + }{ + {"valid simple", "gmail", "from_addr LIKE '%@acme.com%'", false}, + {"valid AND", "gmail", "from_addr = 'x' AND subject LIKE '%urgent%'", false}, + {"valid OR parens", "gmail", "(from_addr LIKE '%@a.com%') OR (subject LIKE '%b%')", false}, + {"valid IS NOT NULL", "gmail", "from_addr IS NOT NULL", false}, + {"valid IN", "gmail", "from_addr IN ('a@b.com', 'c@d.com')", false}, + {"valid whatsapp", "whatsapp", "sender_name = 'Alice' AND text LIKE '%meeting%'", false}, + {"empty", "gmail", "", true}, + {"semicolon", "gmail", "1=1; DROP TABLE emails", true}, + {"comment", "gmail", "from_addr = 'x' -- comment", true}, + {"unbalanced parens", "gmail", "from_addr LIKE '(' AND (subject = 'x'", true}, + + // SQL injection attempts — blocked by allowlist + {"union select", "gmail", "1=1 UNION SELECT sql FROM sqlite_master", true}, + {"subquery", "gmail", "from_addr IN (SELECT from_addr FROM emails)", true}, + {"drop", "gmail", "DROP TABLE emails", true}, + {"delete", "gmail", "DELETE FROM emails", true}, + {"insert", "gmail", "INSERT INTO emails VALUES(1)", true}, + {"update", "gmail", "UPDATE emails SET x=1", true}, + {"attach", "gmail", "ATTACH DATABASE '/tmp/x' AS x", true}, + {"pragma", "gmail", "PRAGMA table_info(emails)", true}, + {"load_extension", "gmail", "load_extension('/tmp/evil')", true}, + {"unknown column", "gmail", "nonexistent_col = 'x'", true}, + {"function call", "gmail", "lower(from_addr) = 'x'", true}, + + // Keywords inside string literals are safe + {"keyword in string", "gmail", "subject LIKE '%DROP TABLE%'", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateTriggerQuery(tt.source, tt.query) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateTriggerQuery(%q, %q) = %v, wantErr %v", tt.source, tt.query, err, tt.wantErr) + } + }) + } +} + +func TestBuildTriggerQuery(t *testing.T) { + q, args, err := BuildTriggerQuery("gmail", "from_addr LIKE '%@test.com%'", 5) + if err != nil { + t.Fatal(err) + } + if len(args) != 1 || args[0] != int64(5) { + t.Errorf("args = %v, want [5]", args) + } + if q == "" { + t.Error("expected non-empty query") + } + + _, _, err = BuildTriggerQuery("nonexistent", "x=1", 0) + if err == nil { + t.Error("expected error for unknown source") + } +} + +func TestCheckTrigger(t *testing.T) { + db := openTestDB(t) + defer db.Close() + + // Create a test emails table matching the gmail trigger template. + _, err := db.Exec(`CREATE TABLE emails ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + subject TEXT, + from_addr TEXT, + date TEXT + )`) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec(`INSERT INTO emails (subject, from_addr, date) VALUES + ('Newsletter', 'news@example.com', '2026-01-01'), + ('Q1 Planning', 'boss@acme.com', '2026-01-02'), + ('Ticket resolved', 'support@vendor.io', '2026-01-03')`) + if err != nil { + t.Fatal(err) + } + + match, err := CheckTrigger(db, "gmail", "from_addr LIKE '%@acme.com%'", 0) + if err != nil { + t.Fatal(err) + } + if match == nil { + t.Fatal("expected match, got nil") + } + if len(match.Rows) != 1 { + t.Errorf("got %d rows, want 1", len(match.Rows)) + } + if match.MaxID != 2 { + t.Errorf("MaxID = %d, want 2", match.MaxID) + } + + // Watermark past the match should return nil. + match, err = CheckTrigger(db, "gmail", "from_addr LIKE '%@acme.com%'", 2) + if err != nil { + t.Fatal(err) + } + if match != nil { + t.Errorf("expected nil match after watermark, got %d rows", len(match.Rows)) + } +} + diff --git a/service/scheduler/types.go b/service/scheduler/types.go index fe7cafbb..f4038ffc 100644 --- a/service/scheduler/types.go +++ b/service/scheduler/types.go @@ -7,6 +7,7 @@ type ScheduleType string const ( Recurring ScheduleType = "recurring" OneShot ScheduleType = "one_shot" + Reactive ScheduleType = "reactive" ) type ChannelMeta struct { @@ -17,18 +18,23 @@ type ChannelMeta struct { } type Schedule struct { - ID int64 - Type ScheduleType - CronExpr string - ScheduledAt *time.Time - Task string - Channel string - ChannelMeta ChannelMeta - Timezone string - Description string - Enabled bool - LastRunAt *time.Time - LastError string - CreatedAt time.Time - CompletedAt *time.Time + ID int64 + Type ScheduleType + CronExpr string + ScheduledAt *time.Time + Task string + Channel string + ChannelMeta ChannelMeta + Timezone string + Description string + Enabled bool + LastRunAt *time.Time + LastError string + CreatedAt time.Time + CompletedAt *time.Time + ModelTier string + MaxBudgetUSD float64 + TriggerSource string + TriggerQuery string + LastTriggerID int64 } diff --git a/service/tasks/cleanup.go b/service/tasks/cleanup.go new file mode 100644 index 00000000..70799b08 --- /dev/null +++ b/service/tasks/cleanup.go @@ -0,0 +1,22 @@ +package tasks + +import ( + "log/slog" + "time" + + "github.com/73ai/openbotkit/store" +) + +const retentionDays = 7 + +func Cleanup(db *store.DB) { + cutoff := time.Now().UTC().Add(-retentionDays * 24 * time.Hour) + n, err := DeleteOlderThan(db, cutoff) + if err != nil { + slog.Warn("tasks cleanup failed", "error", err) + return + } + if n > 0 { + slog.Info("tasks cleanup", "deleted", n) + } +} diff --git a/service/tasks/schema.go b/service/tasks/schema.go new file mode 100644 index 00000000..d57a3e6b --- /dev/null +++ b/service/tasks/schema.go @@ -0,0 +1,42 @@ +package tasks + +import "github.com/73ai/openbotkit/store" + +const schemaSQLite = ` +CREATE TABLE IF NOT EXISTS tasks ( + id TEXT PRIMARY KEY, + task TEXT NOT NULL, + agent TEXT NOT NULL, + status TEXT NOT NULL, + started_at TEXT NOT NULL, + done_at TEXT, + output TEXT, + error TEXT +); +CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status); +CREATE INDEX IF NOT EXISTS idx_tasks_started_at ON tasks(started_at); +` + +const schemaPostgres = ` +CREATE TABLE IF NOT EXISTS tasks ( + id TEXT PRIMARY KEY, + task TEXT NOT NULL, + agent TEXT NOT NULL, + status TEXT NOT NULL, + started_at TIMESTAMPTZ NOT NULL, + done_at TIMESTAMPTZ, + output TEXT, + error TEXT +); +CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status); +CREATE INDEX IF NOT EXISTS idx_tasks_started_at ON tasks(started_at); +` + +func Migrate(db *store.DB) error { + schema := schemaSQLite + if db.IsPostgres() { + schema = schemaPostgres + } + _, err := db.Exec(schema) + return err +} diff --git a/service/tasks/store.go b/service/tasks/store.go new file mode 100644 index 00000000..5fddc3e7 --- /dev/null +++ b/service/tasks/store.go @@ -0,0 +1,154 @@ +package tasks + +import ( + "database/sql" + "fmt" + "time" + + "github.com/73ai/openbotkit/store" +) + +const timeFormat = "2006-01-02T15:04:05Z" + +type TaskRecord struct { + ID string + Task string + Agent string + Status string + StartedAt time.Time + DoneAt *time.Time + Output string + Error string +} + +func Insert(db *store.DB, r *TaskRecord) error { + _, err := db.Exec( + db.Rebind(`INSERT INTO tasks (id, task, agent, status, started_at) VALUES (?, ?, ?, ?, ?)`), + r.ID, r.Task, r.Agent, r.Status, r.StartedAt.UTC().Format(timeFormat), + ) + if err != nil { + return fmt.Errorf("insert task: %w", err) + } + return nil +} + +func SetCompleted(db *store.DB, id, output string) error { + _, err := db.Exec( + db.Rebind(`UPDATE tasks SET status = 'completed', output = ?, done_at = ? WHERE id = ?`), + output, time.Now().UTC().Format(timeFormat), id, + ) + if err != nil { + return fmt.Errorf("set task completed: %w", err) + } + return nil +} + +func SetFailed(db *store.DB, id, errMsg string) error { + _, err := db.Exec( + db.Rebind(`UPDATE tasks SET status = 'failed', error = ?, done_at = ? WHERE id = ?`), + errMsg, time.Now().UTC().Format(timeFormat), id, + ) + if err != nil { + return fmt.Errorf("set task failed: %w", err) + } + return nil +} + +func Get(db *store.DB, id string) (*TaskRecord, error) { + row := db.QueryRow( + db.Rebind(`SELECT id, task, agent, status, started_at, done_at, output, error FROM tasks WHERE id = ?`), + id, + ) + return scanTask(row) +} + +func List(db *store.DB) ([]*TaskRecord, error) { + rows, err := db.Query(`SELECT id, task, agent, status, started_at, done_at, output, error FROM tasks ORDER BY started_at DESC`) + if err != nil { + return nil, fmt.Errorf("list tasks: %w", err) + } + defer rows.Close() + return scanTasks(rows) +} + +func CountRunning(db *store.DB) (int, error) { + var count int + err := db.QueryRow(`SELECT COUNT(*) FROM tasks WHERE status = 'running'`).Scan(&count) + if err != nil { + return 0, fmt.Errorf("count running: %w", err) + } + return count, nil +} + +func DeleteOlderThan(db *store.DB, before time.Time) (int64, error) { + res, err := db.Exec( + db.Rebind(`DELETE FROM tasks WHERE status IN ('completed', 'failed') AND done_at < ?`), + before.UTC().Format(timeFormat), + ) + if err != nil { + return 0, fmt.Errorf("delete old tasks: %w", err) + } + return res.RowsAffected() +} + +type scannable interface { + Scan(dest ...any) error +} + +func scanTask(row scannable) (*TaskRecord, error) { + var r TaskRecord + var startedAt, doneAt, output, errMsg sql.NullString + + err := row.Scan(&r.ID, &r.Task, &r.Agent, &r.Status, &startedAt, &doneAt, &output, &errMsg) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, fmt.Errorf("scan task: %w", err) + } + + r.Output = output.String + r.Error = errMsg.String + + if startedAt.Valid { + t, err := parseTime(startedAt.String) + if err != nil { + return nil, fmt.Errorf("parse started_at: %w", err) + } + r.StartedAt = *t + } + if doneAt.Valid { + t, err := parseTime(doneAt.String) + if err != nil { + return nil, fmt.Errorf("parse done_at: %w", err) + } + r.DoneAt = t + } + + return &r, nil +} + +func scanTasks(rows *sql.Rows) ([]*TaskRecord, error) { + var result []*TaskRecord + for rows.Next() { + r, err := scanTask(rows) + if err != nil { + return nil, err + } + result = append(result, r) + } + return result, rows.Err() +} + +func parseTime(s string) (*time.Time, error) { + for _, f := range []string{ + timeFormat, + "2006-01-02 15:04:05", + time.RFC3339, + } { + if t, err := time.Parse(f, s); err == nil { + return &t, nil + } + } + return nil, fmt.Errorf("unrecognised time format: %q", s) +} diff --git a/service/tasks/store_test.go b/service/tasks/store_test.go new file mode 100644 index 00000000..6ad625c5 --- /dev/null +++ b/service/tasks/store_test.go @@ -0,0 +1,184 @@ +package tasks + +import ( + "testing" + "time" + + "github.com/73ai/openbotkit/store" +) + +func openTestDB(t *testing.T) *store.DB { + t.Helper() + db, err := store.Open(store.SQLiteConfig(":memory:")) + if err != nil { + t.Fatalf("open db: %v", err) + } + if err := Migrate(db); err != nil { + db.Close() + t.Fatalf("migrate: %v", err) + } + t.Cleanup(func() { db.Close() }) + return db +} + +func TestInsertAndGet(t *testing.T) { + db := openTestDB(t) + now := time.Now().UTC().Truncate(time.Second) + r := &TaskRecord{ + ID: "t1", Task: "do stuff", Agent: "claude", + Status: "running", StartedAt: now, + } + if err := Insert(db, r); err != nil { + t.Fatalf("Insert: %v", err) + } + got, err := Get(db, "t1") + if err != nil { + t.Fatalf("Get: %v", err) + } + if got == nil { + t.Fatal("expected non-nil result") + } + if got.ID != "t1" || got.Task != "do stuff" || got.Agent != "claude" || got.Status != "running" { + t.Errorf("unexpected record: %+v", got) + } + if got.StartedAt.Truncate(time.Second) != now { + t.Errorf("StartedAt = %v, want %v", got.StartedAt, now) + } +} + +func TestSetCompleted(t *testing.T) { + db := openTestDB(t) + Insert(db, &TaskRecord{ID: "t1", Task: "t", Agent: "claude", Status: "running", StartedAt: time.Now()}) + if err := SetCompleted(db, "t1", "result text"); err != nil { + t.Fatalf("SetCompleted: %v", err) + } + got, _ := Get(db, "t1") + if got.Status != "completed" { + t.Errorf("Status = %q", got.Status) + } + if got.Output != "result text" { + t.Errorf("Output = %q", got.Output) + } + if got.DoneAt == nil { + t.Error("DoneAt should be set") + } +} + +func TestSetFailed(t *testing.T) { + db := openTestDB(t) + Insert(db, &TaskRecord{ID: "t1", Task: "t", Agent: "claude", Status: "running", StartedAt: time.Now()}) + if err := SetFailed(db, "t1", "timeout"); err != nil { + t.Fatalf("SetFailed: %v", err) + } + got, _ := Get(db, "t1") + if got.Status != "failed" { + t.Errorf("Status = %q", got.Status) + } + if got.Error != "timeout" { + t.Errorf("Error = %q", got.Error) + } +} + +func TestGetNotFound(t *testing.T) { + db := openTestDB(t) + got, err := Get(db, "nonexistent") + if err != nil { + t.Fatalf("Get: %v", err) + } + if got != nil { + t.Errorf("expected nil, got %+v", got) + } +} + +func TestListOrdering(t *testing.T) { + db := openTestDB(t) + t1 := time.Now().UTC().Add(-2 * time.Hour) + t2 := time.Now().UTC().Add(-1 * time.Hour) + t3 := time.Now().UTC() + Insert(db, &TaskRecord{ID: "old", Task: "t", Agent: "a", Status: "completed", StartedAt: t1}) + Insert(db, &TaskRecord{ID: "mid", Task: "t", Agent: "a", Status: "running", StartedAt: t2}) + Insert(db, &TaskRecord{ID: "new", Task: "t", Agent: "a", Status: "running", StartedAt: t3}) + + list, err := List(db) + if err != nil { + t.Fatalf("List: %v", err) + } + if len(list) != 3 { + t.Fatalf("got %d tasks", len(list)) + } + if list[0].ID != "new" || list[1].ID != "mid" || list[2].ID != "old" { + t.Errorf("order: %s, %s, %s", list[0].ID, list[1].ID, list[2].ID) + } +} + +func TestCountRunning(t *testing.T) { + db := openTestDB(t) + Insert(db, &TaskRecord{ID: "t1", Task: "t", Agent: "a", Status: "running", StartedAt: time.Now()}) + Insert(db, &TaskRecord{ID: "t2", Task: "t", Agent: "a", Status: "completed", StartedAt: time.Now()}) + Insert(db, &TaskRecord{ID: "t3", Task: "t", Agent: "a", Status: "running", StartedAt: time.Now()}) + + count, err := CountRunning(db) + if err != nil { + t.Fatalf("CountRunning: %v", err) + } + if count != 2 { + t.Errorf("count = %d, want 2", count) + } +} + +func TestDeleteOlderThan(t *testing.T) { + db := openTestDB(t) + old := time.Now().UTC().Add(-10 * 24 * time.Hour) + recent := time.Now().UTC() + Insert(db, &TaskRecord{ID: "old", Task: "t", Agent: "a", Status: "completed", StartedAt: old}) + SetCompleted(db, "old", "done") + // Manually set done_at to old time + db.Exec(db.Rebind(`UPDATE tasks SET done_at = ? WHERE id = ?`), old.Format(timeFormat), "old") + Insert(db, &TaskRecord{ID: "new", Task: "t", Agent: "a", Status: "completed", StartedAt: recent}) + SetCompleted(db, "new", "done") + + cutoff := time.Now().UTC().Add(-7 * 24 * time.Hour) + n, err := DeleteOlderThan(db, cutoff) + if err != nil { + t.Fatalf("DeleteOlderThan: %v", err) + } + if n != 1 { + t.Errorf("deleted %d, want 1", n) + } + got, _ := Get(db, "old") + if got != nil { + t.Error("old task should be deleted") + } + got, _ = Get(db, "new") + if got == nil { + t.Error("new task should still exist") + } +} + +func TestCleanup(t *testing.T) { + db := openTestDB(t) + old := time.Now().UTC().Add(-10 * 24 * time.Hour) + Insert(db, &TaskRecord{ID: "old1", Task: "t", Agent: "a", Status: "completed", StartedAt: old}) + SetCompleted(db, "old1", "done") + db.Exec(db.Rebind(`UPDATE tasks SET done_at = ? WHERE id = ?`), old.Format(timeFormat), "old1") + + Insert(db, &TaskRecord{ID: "recent", Task: "t", Agent: "a", Status: "running", StartedAt: time.Now().UTC()}) + + Cleanup(db) + + got, _ := Get(db, "old1") + if got != nil { + t.Error("old task should be cleaned up") + } + got, _ = Get(db, "recent") + if got == nil { + t.Error("recent task should still exist") + } +} + +func TestMigrateIdempotent(t *testing.T) { + db := openTestDB(t) + if err := Migrate(db); err != nil { + t.Fatalf("second Migrate: %v", err) + } +} diff --git a/spectest/local_fixture.go b/spectest/local_fixture.go index 0efeeadf..c87944a7 100644 --- a/spectest/local_fixture.go +++ b/spectest/local_fixture.go @@ -183,7 +183,7 @@ func availableProviders(t *testing.T) []providerCase { func createSourceDirs(t *testing.T, dir string) { t.Helper() - for _, src := range []string{"gmail", "whatsapp", "history", "user_memory", "applenotes", "contacts", "scheduler"} { + for _, src := range []string{"gmail", "whatsapp", "history", "user_memory", "applenotes", "contacts", "scheduler", "tasks"} { if err := os.MkdirAll(filepath.Join(dir, src), 0700); err != nil { t.Fatalf("mkdir %s: %v", src, err) } diff --git a/spectest/reactive_trigger_test.go b/spectest/reactive_trigger_test.go new file mode 100644 index 00000000..24b5a072 --- /dev/null +++ b/spectest/reactive_trigger_test.go @@ -0,0 +1,295 @@ +package spectest + +import ( + "context" + "database/sql" + "path/filepath" + "testing" + "time" + + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdriver/riversqlite" + "github.com/riverqueue/river/rivermigrate" + + "github.com/73ai/openbotkit/channel" + "github.com/73ai/openbotkit/config" + "github.com/73ai/openbotkit/daemon" + "github.com/73ai/openbotkit/daemon/jobs" + "github.com/73ai/openbotkit/service/scheduler" + "github.com/73ai/openbotkit/store" +) + +func TestSpec_ReactiveEmailTrigger(t *testing.T) { + fx := NewLocalFixture(t) + + // Seed emails — only the boss email should trigger. + fx.GivenEmails(t, []Email{ + {From: "newsletter@news.com", To: "me@example.com", + Subject: "Weekly Digest", Body: "Here are this week's top stories..."}, + {From: "boss@acme.com", To: "me@example.com", + Subject: "Q1 Planning", Body: "Please review the Q1 targets before our meeting tomorrow."}, + {From: "support@vendor.io", To: "me@example.com", + Subject: "Ticket #1234 resolved", Body: "Your support ticket has been resolved."}, + }) + + dir := fx.dir + schedDBPath := filepath.Join(dir, "scheduler", "data.db") + jobsDBPath := filepath.Join(dir, "jobs.db") + gmailDBPath := filepath.Join(dir, "gmail", "data.db") + + // Create scheduler DB with reactive schedule. + sdb, err := store.Open(store.SQLiteConfig(schedDBPath)) + if err != nil { + t.Fatalf("open sched db: %v", err) + } + if err := scheduler.Migrate(sdb); err != nil { + t.Fatalf("migrate sched: %v", err) + } + + meta := scheduler.ChannelMeta{BotToken: "test", OwnerID: 1} + + scheduleID, err := scheduler.Create(sdb, &scheduler.Schedule{ + Type: scheduler.Reactive, + TriggerSource: "gmail", + TriggerQuery: "from_addr LIKE '%@acme.com%'", + Task: "Summarize this email in one sentence.", + Channel: "test", + ChannelMeta: meta, + Timezone: "UTC", + Description: "Summarize emails from Acme Corp", + ModelTier: "fast", + MaxBudgetUSD: 1.0, + }) + if err != nil { + t.Fatalf("create schedule: %v", err) + } + sdb.Close() + + cfg := &config.Config{ + Scheduler: &config.SchedulerConfig{ + Storage: config.StorageConfig{Driver: "sqlite", DSN: schedDBPath}, + }, + } + + // Set up River. + jobsDB, err := sql.Open("sqlite", jobsDBPath) + if err != nil { + t.Fatalf("open jobs db: %v", err) + } + defer jobsDB.Close() + jobsDB.SetMaxOpenConns(1) + + driver := riversqlite.New(jobsDB) + migrator, err := rivermigrate.New(driver, nil) + if err != nil { + t.Fatalf("create migrator: %v", err) + } + if _, err := migrator.Migrate(context.Background(), rivermigrate.DirectionUp, nil); err != nil { + t.Fatalf("run migrations: %v", err) + } + + // Set up worker with real LLM and mock pusher. + pusher := &capturePusher{} + worker := &jobs.ScheduledTaskWorker{ + Cfg: cfg, + RunAgentFunc: makeAgentRunner(fx.Provider, fx.Model), + MakePusher: func(_ string, _ scheduler.ChannelMeta) (channel.Pusher, error) { + return pusher, nil + }, + } + + workers := river.NewWorkers() + river.AddWorker(workers, worker) + + riverClient, err := river.NewClient(driver, &river.Config{ + Queues: map[string]river.QueueConfig{river.QueueDefault: {MaxWorkers: 1}}, + Workers: workers, + }) + if err != nil { + t.Fatalf("create river client: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + if err := riverClient.Start(ctx); err != nil { + t.Fatalf("start river: %v", err) + } + defer riverClient.Stop(context.Background()) + + // Execute reactive trigger check. + notifier := daemon.NewSyncNotifier() + sched := daemon.NewScheduler(cfg, riverClient, jobsDB, notifier) + + gmailDB, err := store.Open(store.SQLiteConfig(gmailDBPath)) + if err != nil { + t.Fatalf("open gmail db: %v", err) + } + defer gmailDB.Close() + + if err := sched.CheckReactiveTriggersForTest(ctx, "gmail", gmailDB); err != nil { + t.Fatalf("reactive check: %v", err) + } + + // Wait for the pusher to receive a message. + deadline := time.After(90 * time.Second) + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-deadline: + t.Fatal("timed out waiting for reactive task result") + case <-ticker.C: + msgs := pusher.Messages() + if len(msgs) == 0 { + continue + } + t.Logf("pushed message: %s", msgs[0]) + + // Verify schedule watermark was updated. + sdb2, err := store.Open(store.SQLiteConfig(schedDBPath)) + if err != nil { + t.Fatalf("reopen sched db: %v", err) + } + defer sdb2.Close() + + s, err := scheduler.Get(sdb2, scheduleID) + if err != nil { + t.Fatalf("get schedule: %v", err) + } + if s.LastTriggerID == 0 { + t.Error("expected last_trigger_id to be updated (watermark)") + } + if s.LastRunAt == nil { + t.Error("expected last_run_at to be set") + } + if s.LastError != "" { + t.Errorf("expected no last_error, got %q", s.LastError) + } + + // Verify re-running trigger doesn't re-fire (watermark prevents it). + if err := sched.CheckReactiveTriggersForTest(ctx, "gmail", gmailDB); err != nil { + t.Fatalf("second reactive check: %v", err) + } + time.Sleep(2 * time.Second) + msgs2 := pusher.Messages() + if len(msgs2) > 1 { + t.Errorf("expected no additional messages after watermark update, got %d total", len(msgs2)) + } + + return + } + } +} + +// TestSpec_ReactiveNoMatchDoesNotFire verifies no job is enqueued when nothing matches. +func TestSpec_ReactiveNoMatchDoesNotFire(t *testing.T) { + fx := NewLocalFixture(t) + + fx.GivenEmails(t, []Email{ + {From: "newsletter@news.com", To: "me@example.com", + Subject: "Weekly Digest", Body: "Here are this week's top stories..."}, + }) + + dir := fx.dir + schedDBPath := filepath.Join(dir, "scheduler", "data.db") + jobsDBPath := filepath.Join(dir, "jobs.db") + gmailDBPath := filepath.Join(dir, "gmail", "data.db") + + sdb, err := store.Open(store.SQLiteConfig(schedDBPath)) + if err != nil { + t.Fatalf("open sched db: %v", err) + } + if err := scheduler.Migrate(sdb); err != nil { + t.Fatalf("migrate sched: %v", err) + } + + meta := scheduler.ChannelMeta{BotToken: "test", OwnerID: 1} + + _, err = scheduler.Create(sdb, &scheduler.Schedule{ + Type: scheduler.Reactive, + TriggerSource: "gmail", + TriggerQuery: "from_addr LIKE '%@doesnotexist.com%'", + Task: "Should not fire", + Channel: "test", + ChannelMeta: meta, + Timezone: "UTC", + Description: "No-match test", + ModelTier: "fast", + }) + if err != nil { + t.Fatalf("create schedule: %v", err) + } + sdb.Close() + + cfg := &config.Config{ + Scheduler: &config.SchedulerConfig{ + Storage: config.StorageConfig{Driver: "sqlite", DSN: schedDBPath}, + }, + } + + jobsDB, err := sql.Open("sqlite", jobsDBPath) + if err != nil { + t.Fatalf("open jobs db: %v", err) + } + defer jobsDB.Close() + jobsDB.SetMaxOpenConns(1) + + driver := riversqlite.New(jobsDB) + migrator, err := rivermigrate.New(driver, nil) + if err != nil { + t.Fatalf("create migrator: %v", err) + } + if _, err := migrator.Migrate(context.Background(), rivermigrate.DirectionUp, nil); err != nil { + t.Fatalf("run migrations: %v", err) + } + + pusher := &capturePusher{} + worker := &jobs.ScheduledTaskWorker{ + Cfg: cfg, + MakePusher: func(_ string, _ scheduler.ChannelMeta) (channel.Pusher, error) { + return pusher, nil + }, + } + + workers := river.NewWorkers() + river.AddWorker(workers, worker) + + riverClient, err := river.NewClient(driver, &river.Config{ + Queues: map[string]river.QueueConfig{river.QueueDefault: {MaxWorkers: 1}}, + Workers: workers, + }) + if err != nil { + t.Fatalf("create river client: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if err := riverClient.Start(ctx); err != nil { + t.Fatalf("start river: %v", err) + } + defer riverClient.Stop(context.Background()) + + notifier := daemon.NewSyncNotifier() + sched := daemon.NewScheduler(cfg, riverClient, jobsDB, notifier) + + gmailDB, err := store.Open(store.SQLiteConfig(gmailDBPath)) + if err != nil { + t.Fatalf("open gmail db: %v", err) + } + defer gmailDB.Close() + + if err := sched.CheckReactiveTriggersForTest(ctx, "gmail", gmailDB); err != nil { + t.Fatalf("reactive check: %v", err) + } + + // Wait and verify no message was pushed. + time.Sleep(3 * time.Second) + msgs := pusher.Messages() + if len(msgs) != 0 { + t.Errorf("expected no messages, got %d", len(msgs)) + } +} +