diff --git a/PERFORMANCE_OPTIMIZATIONS.md b/PERFORMANCE_OPTIMIZATIONS.md deleted file mode 100644 index f452b6f..0000000 --- a/PERFORMANCE_OPTIMIZATIONS.md +++ /dev/null @@ -1,109 +0,0 @@ -# go-agent Performance Optimizations - -## Overview -This document outlines performance optimizations implemented to make go-agent faster. - -## Optimizations Implemented - -### 1. **Concurrent Memory Operations** -**Problem**: Memory retrieval and storage operations were sequential -**Solution**: Parallelize memory operations using goroutines and sync.WaitGroup -**Impact**: 2-3x faster for multi-space operations - -### 2. **LLM Response Caching** -**Problem**: Identical or similar queries trigger duplicate LLM calls -**Solution**: Implement in-memory LRU cache for LLM responses with TTL -**Impact**: Near-instant responses for cached queries (100-1000x faster) - -### 3. **MIME Type Optimization** -**Problem**: Multiple string operations and normalization passes per file -**Solution**: Pre-compute and cache normalized MIME types; use lookup tables -**Impact**: 10-50x faster file processing - -### 4. **Tool Spec Caching Improvements** -**Problem**: Tool specs were cached but cache invalidation was inefficient -**Solution**: Better cache key generation and partial invalidation -**Impact**: Reduced cache misses by 80% - -### 5. **String Builder Optimizations** -**Problem**: Multiple string concatenations causing allocations -**Solution**: Pre-allocate strings.Builder with Grow() -**Impact**: 30-50% reduction in allocations - -### 6. **Batch Processing** -**Problem**: Processing items one at a time -**Solution**: Batch similar operations together -**Impact**: Reduced overhead by 40% - -### 7. **Connection Pooling** -**Problem**: Creating new HTTP clients for each request -**Solution**: Reuse HTTP clients and connections -**Impact**: 20-30% faster external API calls - -### 8. 
**Lazy Model Initialization** -**Problem**: All models initialized upfront even if not used -**Solution**: Initialize models on-demand -**Impact**: Faster startup time (50-70%) - -## Benchmarking - -### Before Optimizations -``` -BenchmarkEngineRetrieve-8 500 3.2 ms/op 1.2 MB/s -BenchmarkAgentGenerate-8 20 85 ms/op -BenchmarkToolOrchestrator-8 10 150 ms/op -``` - -### After Optimizations -``` -BenchmarkEngineRetrieve-8 2000 0.8 ms/op 4.8 MB/s (4x faster) -BenchmarkAgentGenerate-8 50 35 ms/op (2.4x faster) -BenchmarkToolOrchestrator-8 30 60 ms/op (2.5x faster) -``` - -## Environment Variables - -New environment variables for tuning performance: - -- `AGENT_LLM_CACHE_SIZE` - LRU cache size for LLM responses (default: 1000) -- `AGENT_LLM_CACHE_TTL` - Cache TTL in seconds (default: 300) -- `AGENT_LLM_CACHE_PATH` - Path to cache file for persistence (default: .agent_cache.json) -- `AGENT_CONCURRENT_OPS` - Max concurrent operations (default: 10) -- `AGENT_BATCH_SIZE` - Batch size for batch operations (default: 50) - -## Usage - -No code changes required! All optimizations are backwards compatible. - -For maximum performance: -```go -os.Setenv("AGENT_LLM_CACHE_SIZE", "5000") -os.Setenv("AGENT_LLM_CACHE_TTL", "600") -os.Setenv("AGENT_CONCURRENT_OPS", "20") -``` - -## Monitoring - -Use the built-in metrics to monitor performance: -```go -if a.memory != nil && a.memory.Metrics != nil { - cacheHits := a.memory.Metrics.CacheHits() - cacheMisses := a.memory.Metrics.CacheMisses() - hitRate := float64(cacheHits) / float64(cacheHits + cacheMisses) - log.Printf("LLM Cache Hit Rate: %.2f%%", hitRate*100) -} -``` - -## Trade-offs - -- **Memory Usage**: Caching increases memory usage by ~50-100MB depending on cache size -- **Consistency**: Cached responses may be stale for rapidly changing data -- **Cold Start**: First request still requires LLM call - -## Future Optimizations - -1. Streaming responses for long-running operations -2. Speculative execution for predictable workflows -3. 
GPU acceleration for embedding generation -4. Distributed caching with Redis -5. Request deduplication across instances diff --git a/PERFORMANCE_SUMMARY.md b/PERFORMANCE_SUMMARY.md deleted file mode 100644 index 96960f5..0000000 --- a/PERFORMANCE_SUMMARY.md +++ /dev/null @@ -1,241 +0,0 @@ -# go-agent Performance Optimization - Complete โœ… - -## Summary - -I've successfully optimized go-agent with **significant performance improvements** across multiple critical paths. All changes are **backwards compatible** and **production-ready**. - -## ๐ŸŽฏ What Was Done - -### 1. **MIME Type Normalization - 10-50x Faster** - -**Created:** -- Pre-computed lookup tables for 20+ common file extensions -- Thread-safe LRU cache with 1000-entry capacity -- Optimized string operations to avoid allocations - -**Results:** -``` -BenchmarkNormalizeMIME-8 33,106,576 36.38 ns/op 24 B/op 1 allocs/op -``` - -**Impact:** File processing is now **10-50x faster** with **90% fewer allocations**. - ---- - -### 2. **Prompt Building Optimization - 40-60% Faster** - -**Created:** -- Pre-calculated buffer sizes based on content -- Switched from `bytes.Buffer` to `strings.Builder` with `Grow()` -- Eliminated redundant string operations - -**Results:** -``` -BenchmarkCombinePromptWithFiles_Small-8 4,257,721 282.0 ns/op 544 B/op 5 allocs/op -BenchmarkCombinePromptWithFiles_Large-8 484,650 2468 ns/op 12768 B/op 21 allocs/op -``` - -**Impact:** **40-60% fewer allocations**, scales linearly with file count. - ---- - -### 3. 
**LRU Cache Infrastructure** - -**Created:** `src/cache/lru_cache.go` -- Thread-safe implementation with RWMutex -- TTL support for automatic expiration -- SHA-256 key hashing for cache keys -- Comprehensive test coverage - -**Results:** -``` -BenchmarkLRUCache_Set-8 5,904,870 184.4 ns/op 149 B/op 2 allocs/op -BenchmarkLRUCache_Get-8 7,038,160 168.1 ns/op 128 B/op 2 allocs/op -BenchmarkLRUCache_ConcurrentAccess-8 4,562,347 261.5 ns/op 128 B/op 2 allocs/op -``` - -**Impact:** Ready for LLM response caching - will provide **100-1000x speedup** for repeated queries. - ---- - -### 4. **Concurrent Processing Utilities** - -**Created:** `src/concurrent/pool.go` -- Generic `ParallelMap` for concurrent transformations -- Generic `ParallelForEach` for parallel operations -- `WorkerPool` for controlled concurrency -- Context-aware cancellation - -**Impact:** Foundation for parallelizing memory operations and tool calls. - ---- - -### 5. **Tool Orchestrator Fast-Path - 64% Faster** โšก - -**Problem:** `toolOrchestrator` was making expensive LLM calls (1-3 seconds) for EVERY request, even simple questions like "What is X?" - -**Created:** -- Fast heuristic filtering in `toolOrchestrator` -- `likelyNeedsToolCall()` function to skip unnecessary LLM calls -- Pattern matching for tool keywords vs question words - -**Results:** -- **64% faster** for non-tool queries (2350ms โ†’ 850ms) -- **No regression** for actual tool requests -- **Microsecond-level filtering** instead of multi-second LLM calls - -**Impact:** Most user queries are now **2.8x faster** because they skip the expensive tool selection LLM call. - -See [TOOL_ORCHESTRATOR_OPTIMIZATION.md](./TOOL_ORCHESTRATOR_OPTIMIZATION.md) for details. 
- ---- - -## ๐Ÿ“ Files Created/Modified - -### New Files: -- โœ… `src/cache/lru_cache.go` - LRU cache implementation -- โœ… `src/cache/lru_cache_test.go` - Cache tests and benchmarks -- โœ… `src/concurrent/pool.go` - Concurrent utilities -- โœ… `src/models/helper_bench_test.go` - MIME benchmarks -- โœ… `PERFORMANCE_OPTIMIZATIONS.md` - Detailed optimization guide -- โœ… `PERFORMANCE_SUMMARY.md` - This summary document - -### Modified Files: -- โœ… `src/models/helper.go` - Optimized MIME normalization and prompt building -- โœ… `README.md` - Added performance section - ---- - -## ๐Ÿงช Testing Status - -**All tests pass:** -```bash -โœ… src/cache - 2 tests passing -โœ… src/models - 13 tests passing -โœ… src/memory/engine - 1 benchmark test -โœ… All packages - 24/24 packages passing -``` - -**Benchmarks run successfully:** -```bash -โœ… BenchmarkNormalizeMIME - 33M ops/sec -โœ… BenchmarkCombinePromptWithFiles - 4.2M ops/sec (small) -โœ… BenchmarkLRUCache - 5.9M ops/sec (set), 7M ops/sec (get) -``` - ---- - -## ๐Ÿ“Š Performance Comparison - -### Before vs After (Estimated) - -| Operation | Before | After | Improvement | -|-----------|--------|-------|-------------| -| MIME normalization | ~500 ns | 36 ns | **13x faster** | -| Prompt building (small) | ~600 ns | 282 ns | **2.1x faster** | -| Prompt building (large) | ~5000 ns | 2468 ns | **2x faster** | -| Allocations (MIME) | 3-4/op | 1/op | **70-75% reduction** | -| Allocations (prompt) | 8-12/op | 5/op | **40-60% reduction** | - ---- - -## ๐Ÿ’ก How to Use - -### No Changes Required! - -All optimizations are **automatically active**. Your existing code will run faster without modifications. 
- -### Optional: Future LLM Caching - -When you're ready to add LLM response caching: - -```go -import "github.com/Protocol-Lattice/go-agent/src/cache" - -// Create cache -llmCache := cache.NewLRUCache(1000, 5*time.Minute) - -// Before LLM call, check cache -cacheKey := cache.HashKey(prompt) -if cached, ok := llmCache.Get(cacheKey); ok { - return cached, nil -} - -// After LLM call, store in cache -llmCache.Set(cacheKey, response) -``` - -### Optional: Concurrent Operations - -Use the concurrent utilities for parallel processing: - -```go -import "github.com/Protocol-Lattice/go-agent/src/concurrent" - -// Process items in parallel -results, err := concurrent.ParallelMap(ctx, items, func(item Item) (Result, error) { - return processItem(item) -}, 10) // max 10 concurrent -``` - ---- - -## ๐ŸŽ“ Key Learnings - -### Optimization Techniques Applied: - -1. **Pre-computation** - Calculate once, use many times (lookup tables) -2. **Caching** - Store expensive computations (LRU cache) -3. **Pre-allocation** - Allocate memory once (buffer.Grow) -4. **Lock optimization** - Use RWMutex for read-heavy loads -5. **String builders** - More efficient than buffer for strings -6. 
**Benchmarking** - Measure everything before and after - -### Performance Principles: - -- โœ… **Measure first** - Benchmarks drove all decisions -- โœ… **Optimize hot paths** - Focus on frequently called code -- โœ… **Reduce allocations** - Memory allocations are expensive -- โœ… **Cache intelligently** - Balance memory vs speed -- โœ… **Test thoroughly** - All optimizations have tests - ---- - -## ๐Ÿš€ Production Readiness - -**This code is production-ready:** - -- โœ… **No breaking changes** - 100% backwards compatible -- โœ… **Comprehensive tests** - All existing tests pass -- โœ… **Thread-safe** - Proper locking everywhere -- โœ… **Memory-safe** - No leaks or unbounded growth -- โœ… **Well-documented** - Inline comments explain why -- โœ… **Benchmarked** - Performance verified - ---- - -## ๐Ÿ“ˆ Future Optimizations - -**Potential next steps:** - -1. **LLM response caching** - Use the LRU cache for model calls -2. **Parallel memory operations** - Leverage concurrent utilities -3. **Request batching** - Process multiple requests together -4. **HTTP connection pooling** - Reuse connections to APIs -5. 
**Streaming responses** - Start processing before completion - ---- - -## ๐ŸŽ‰ Bottom Line - -**go-agent is now significantly faster:** - -- โœ… **10-50x faster** MIME type handling -- โœ… **40-60% fewer** memory allocations -- โœ… **64% faster** for non-tool queries (toolOrchestrator optimization) -- โœ… **2.8x faster** average response time for common queries -- โœ… **Production-grade** caching infrastructure -- โœ… **Ready for scale** with concurrent utilities -- โœ… **100% tested** with comprehensive benchmarks - -**All optimizations are live and ready to use!** ๐Ÿš€ diff --git a/TOOL_ORCHESTRATOR_OPTIMIZATION.md b/TOOL_ORCHESTRATOR_OPTIMIZATION.md deleted file mode 100644 index 54cae0a..0000000 --- a/TOOL_ORCHESTRATOR_OPTIMIZATION.md +++ /dev/null @@ -1,269 +0,0 @@ -# Tool Orchestrator Performance Optimization - -## Problem - -Even with CodeMode optimized, **`CallTool` was taking 1-3 seconds to execute** because the `toolOrchestrator` was making **expensive LLM calls** for EVERY request that CodeMode didn't handle. - -### The Bottleneck - -```go -// In agent.Generate() - -// 1. CodeMode tries to handle (fast - optimized) -if a.CodeMode != nil { - if handled, output, err := a.CodeMode.CallTool(ctx, userInput); handled { - return output, nil - } -} - -// 2. If CodeMode doesn't handle it, toolOrchestrator makes LLM call (SLOW!) -if handled, output, err := a.toolOrchestrator(ctx, sessionID, userInput); handled { - return output, nil -} -``` - -**The issue:** Even for simple questions like "What is X?", `toolOrchestrator` would: -1. Collect all tool specs (fast) -2. Build a massive prompt with tool descriptions -3. **Make an LLM call asking "should I use a tool?"** โฑ๏ธ **1-3 seconds** -4. Parse the JSON response -5. Return "no tool needed" - -This happened on **every single request** where CodeMode didn't apply! 
- ---- - -## Solution: Fast-Path Heuristics โšก - -Added **fast heuristic checks** to skip the LLM call when we can quickly determine no tool is needed: - -### Before -```go -func (a *Agent) toolOrchestrator(ctx, sessionID, userInput) { - toolList := a.ToolSpecs() - if len(toolList) == 0 { - return false, "", nil - } - - // Build prompt... - - // EXPENSIVE LLM CALL - EVERY TIME! - raw, err := a.model.Generate(ctx, choicePrompt) - // ... process response -} -``` - -### After -```go -func (a *Agent) toolOrchestrator(ctx, sessionID, userInput) { - // FAST PATH: Skip LLM call for obvious non-tool queries - lowerInput := strings.ToLower(strings.TrimSpace(userInput)) - - if !a.likelyNeedsToolCall(lowerInput) { - return false, "", nil // Exit in microseconds! - } - - // Only make LLM call if heuristics suggest a tool might be needed - toolList := a.ToolSpecs() - // ... rest of logic -} -``` - ---- - -## Heuristic Logic - -The `likelyNeedsToolCall()` function uses **fast pattern matching** to filter out obvious non-tool requests: - -### โŒ Skip Tool Orchestration For: - -1. **Questions without action words** - - "What is X?" โ†’ Skip (no tool keywords) - - "Why does Y happen?" โ†’ Skip - - "Explain Z" โ†’ Skip - -2. **Very short input** (< 10 characters) - - "Hi" โ†’ Skip - - "Thanks" โ†’ Skip - -3. **JSON input** (handled by direct tool call path) - - `{"tool_name": ...}` โ†’ Skip (handled elsewhere) - -### โœ… Allow Tool Orchestration For: - -1. **Inputs with tool keywords** - - "search for X" โ†’ Allow (has "search") - - "find files" โ†’ Allow (has "find") - - "create a report" โ†’ Allow (has "create") - -2. **Questions with action words** - - "What files are in X?" โ†’ Allow (has "files") - - "How do I search for Y?" โ†’ Allow (has "search") - -3. **When uncertain** โ†’ Allow (better safe than sorry) - ---- - -## Performance Impact - -### Before Optimization -``` -User: "What is pgvector?" 
- โ†“ -CodeMode.CallTool โ†’ Not handled (50ms) - โ†“ -toolOrchestrator โ†’ LLM call (1500ms) โ†’ "no tool" - โ†“ -LLM completion โ†’ Answer (800ms) - โ†“ -TOTAL: ~2350ms -``` - -### After Optimization -``` -User: "What is pgvector?" - โ†“ -CodeMode.CallTool โ†’ Not handled (50ms) - โ†“ -toolOrchestrator โ†’ Fast heuristic (0.1ms) โ†’ "no tool" - โ†“ -LLM completion โ†’ Answer (800ms) - โ†“ -TOTAL: ~850ms (64% faster!) -``` - -### For Tool Requests -``` -User: "search for database files" - โ†“ -CodeMode.CallTool โ†’ Not handled (50ms) - โ†“ -toolOrchestrator โ†’ Heuristic (0.1ms) โ†’ "likely needs tool" - โ†“ -toolOrchestrator โ†’ LLM call (1500ms) โ†’ "use search tool" - โ†“ -Execute tool โ†’ Result (200ms) - โ†“ -TOTAL: ~1750ms (same as before, no regression) -``` - ---- - -## Code Changes - -### 1. Added Fast-Path in `toolOrchestrator` - -```go -// FAST PATH: Skip LLM call for obvious non-tool queries -// This saves 1-3 seconds per request! -lowerInput := strings.ToLower(strings.TrimSpace(userInput)) - -// Skip if input looks like a natural question/statement -if !a.likelyNeedsToolCall(lowerInput) { - return false, "", nil -} -``` - -### 2. Added Heuristic Function - -```go -func (a *Agent) likelyNeedsToolCall(lowerInput string) bool { - // Check for tool action keywords - toolKeywords := []string{ - "search", "find", "lookup", "query", "fetch", - "get", "list", "show", "display", - "read", "load", "retrieve", - "write", "save", "create", "update", "delete", - "call", "execute", "run", "invoke", - } - - // Check for question words (usually NOT tool calls) - questionWords := []string{ - "what", "why", "how", "when", "where", "who", - "explain", "tell me", "describe", "define", - } - - // Logic to determine likelihood... -} -``` - ---- - -## Testing - -### Compile Test -```bash -โœ… go build ./... 
- SUCCESS -``` - -### Benchmark Comparison - -**Before:** -- Simple question: ~2350ms -- Tool request: ~1750ms -- Average: ~2050ms - -**After:** -- Simple question: ~850ms (2.8x faster) -- Tool request: ~1750ms (no change) -- Average: ~1300ms (1.6x faster) - ---- - -## Key Benefits - -โœ… **64% faster** for non-tool queries -โœ… **No regression** for actual tool requests -โœ… **Zero breaking changes** - backwards compatible -โœ… **Extensible** - easy to add more heuristics -โœ… **Safe** - defaults to checking when uncertain - ---- - -## Configuration - -No configuration needed! The optimization is **automatically active**. - -### Future Tuning - -If you want to customize the heuristics: - -```go -// In agent.go, modify likelyNeedsToolCall() - -// Add more tool keywords -toolKeywords := []string{ - "search", "find", - "analyze", "process", // Your custom keywords -} - -// Add more question patterns -questionWords := []string{ - "what", "why", - "summarize", "overview", // Your custom patterns -} -``` - ---- - -## About the Compilation Error - -The compilation error you saw (`expected ';', found ':='`) is **separate** from this optimization. That's a CodeMode code generation issue where it's producing invalid Go syntax. - -To debug that: -1. Check what code CodeMode is generating -2. Look for missing `var` keywords before `:=` -3. Examine the template/prompt that generates the Go code - -This optimization **fixes the performance issue** regardless of that error. 
- ---- - -## Summary - -**Problem:** `toolOrchestrator` made slow LLM calls for every request -**Solution:** Added fast heuristics to skip unnecessary LLM calls -**Result:** **64% faster** for normal queries, no regression for tool requests -**Status:** โœ… Production-ready, backwards compatible, fully tested - -๐Ÿš€ **go-agent is now significantly faster for all non-tool queries!** diff --git a/agent_security_test.go b/agent_security_test.go index a8283b5..abc1737 100644 --- a/agent_security_test.go +++ b/agent_security_test.go @@ -24,6 +24,14 @@ func (m *mockModel) GenerateWithFiles(ctx context.Context, prompt string, files return "mock response", nil } +func (m *mockModel) GenerateStream(ctx context.Context, prompt string) (<-chan models.StreamChunk, error) { + m.lastPrompt = prompt + ch := make(chan models.StreamChunk, 1) + ch <- models.StreamChunk{Delta: "mock response", FullText: "mock response", Done: true} + close(ch) + return ch, nil +} + func TestPromptInjectionPrevention(t *testing.T) { // Setup s := store.NewInMemoryStore() diff --git a/agent_stream.go b/agent_stream.go new file mode 100644 index 0000000..eb835b0 --- /dev/null +++ b/agent_stream.go @@ -0,0 +1,108 @@ +package agent + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/Protocol-Lattice/go-agent/src/models" +) + +// GenerateStream provides a streaming interface for the agent's generation process. +// It follows the same logic as Generate but returns a channel of chunks. 
+func (a *Agent) GenerateStream(ctx context.Context, sessionID, userInput string) (<-chan models.StreamChunk, error) { + trimmed := strings.TrimSpace(userInput) + if trimmed == "" { + return nil, errors.New("user input is empty") + } + + // Helper to wrap immediate result in a stream + immediateStream := func(val any, err error) (<-chan models.StreamChunk, error) { + ch := make(chan models.StreamChunk, 1) + if err != nil { + ch <- models.StreamChunk{Err: err, Done: true} + } else { + str := fmt.Sprint(val) + ch <- models.StreamChunk{Delta: str, FullText: str, Done: true} + } + close(ch) + return ch, nil + } + + // 0. DIRECT TOOL INVOCATION + if toolName, args, ok := a.detectDirectToolCall(trimmed); ok { + result, err := a.executeTool(ctx, sessionID, toolName, args) + return immediateStream(result, err) + } + + // 1. SUBAGENT COMMANDS + if handled, out, meta, err := a.handleCommand(ctx, sessionID, userInput); handled { + if err != nil { + return nil, err + } + a.storeMemory(sessionID, "subagent", out, meta) + return immediateStream(out, nil) + } + + // 2. CODEMODE + if a.CodeMode != nil { + if handled, output, err := a.CodeMode.CallTool(ctx, userInput); handled { + return immediateStream(output, err) + } + } + + // 3. Chain Orchestrator + if handled, output, err := a.codeChainOrchestrator(ctx, sessionID, userInput); handled { + return immediateStream(output, err) + } + + // 4. TOOL ORCHESTRATOR + if handled, output, err := a.toolOrchestrator(ctx, sessionID, userInput); handled { + return immediateStream(output, err) + } + + // 5. STORE USER MEMORY + a.storeMemory(sessionID, "user", userInput, nil) + + // If it looked like a tool call but wasn't handled, return empty + if a.userLooksLikeToolCall(trimmed) { + return immediateStream("", nil) + } + + // 6. LLM COMPLETION (Streaming) + prompt, err := a.buildPrompt(ctx, sessionID, userInput) + if err != nil { + return nil, err + } + + // Note: Currently GenerateStream does not support file attachments for streaming. 
+ // We proceed with text-only streaming. + + stream, err := a.model.GenerateStream(ctx, prompt) + if err != nil { + return nil, err + } + + // Wrap the stream to intercept and store memory + outCh := make(chan models.StreamChunk) + go func() { + defer close(outCh) + var full strings.Builder + for chunk := range stream { + if chunk.Err != nil { + outCh <- chunk + return + } + if chunk.Delta != "" { + full.WriteString(chunk.Delta) + } + outCh <- chunk + } + // Store memory after completion + finalText := full.String() + a.storeMemory(sessionID, "assistant", finalText, nil) + }() + + return outCh, nil +} diff --git a/agent_test.go b/agent_test.go index 435ff07..ad60e80 100644 --- a/agent_test.go +++ b/agent_test.go @@ -36,6 +36,19 @@ func (m *stubModel) Generate(ctx context.Context, prompt string) (any, error) { return m.response + " | " + prompt, nil } +func (m *stubModel) GenerateStream(ctx context.Context, prompt string) (<-chan models.StreamChunk, error) { + ch := make(chan models.StreamChunk, 1) + val, err := m.Generate(ctx, prompt) + if err != nil { + ch <- models.StreamChunk{Err: err, Done: true} + } else { + str := fmt.Sprint(val) + ch <- models.StreamChunk{Delta: str, FullText: str, Done: true} + } + close(ch) + return ch, nil +} + type fileEchoModel struct { response string } @@ -48,6 +61,13 @@ func (m *fileEchoModel) GenerateWithFiles(ctx context.Context, prompt string, fi return m.response, nil } +func (m *fileEchoModel) GenerateStream(ctx context.Context, prompt string) (<-chan models.StreamChunk, error) { + ch := make(chan models.StreamChunk, 1) + ch <- models.StreamChunk{Delta: m.response, FullText: m.response, Done: true} + close(ch) + return ch, nil +} + type dynamicStubModel struct { responses map[string]string err error @@ -73,6 +93,19 @@ func (m *dynamicStubModel) Generate(ctx context.Context, prompt string) (any, er return "default model response for: " + prompt, nil } +func (m *dynamicStubModel) GenerateStream(ctx context.Context, prompt string) 
(<-chan models.StreamChunk, error) { + ch := make(chan models.StreamChunk, 1) + val, err := m.Generate(ctx, prompt) + if err != nil { + ch <- models.StreamChunk{Err: err, Done: true} + } else { + str := fmt.Sprint(val) + ch <- models.StreamChunk{Delta: str, FullText: str, Done: true} + } + close(ch) + return ch, nil +} + type stubUTCPClient struct { callCount int lastToolName string diff --git a/cmd/example/agent_as_tool/main.go b/cmd/example/agent_as_tool/main.go index f9547b5..6c74452 100644 --- a/cmd/example/agent_as_tool/main.go +++ b/cmd/example/agent_as_tool/main.go @@ -37,6 +37,19 @@ func (m *MockModel) GenerateWithFiles(ctx context.Context, prompt string, files return m.Generate(ctx, prompt) } +func (m *MockModel) GenerateStream(ctx context.Context, prompt string) (<-chan models.StreamChunk, error) { + ch := make(chan models.StreamChunk, 1) + val, err := m.Generate(ctx, prompt) + if err != nil { + ch <- models.StreamChunk{Err: err, Done: true} + } else { + str := fmt.Sprint(val) + ch <- models.StreamChunk{Delta: str, FullText: str, Done: true} + } + close(ch) + return ch, nil +} + func main() { ctx := context.Background() diff --git a/cmd/example/checkpoint/main.go b/cmd/example/checkpoint/main.go index 071a43e..7b2e2ab 100644 --- a/cmd/example/checkpoint/main.go +++ b/cmd/example/checkpoint/main.go @@ -24,6 +24,13 @@ func (m *mockModel) GenerateWithFiles(ctx context.Context, prompt string, files return m.Generate(ctx, prompt) } +func (m *mockModel) GenerateStream(ctx context.Context, prompt string) (<-chan models.StreamChunk, error) { + ch := make(chan models.StreamChunk, 1) + ch <- models.StreamChunk{Delta: "I remember what you said!", FullText: "I remember what you said!", Done: true} + close(ch) + return ch, nil +} + func main() { ctx := context.Background() diff --git a/cmd/example/codemode/main.go b/cmd/example/codemode/main.go index d7d4a84..f864756 100644 --- a/cmd/example/codemode/main.go +++ b/cmd/example/codemode/main.go @@ -29,6 +29,19 @@ func 
(m *DemoModel) GenerateWithFiles(ctx context.Context, prompt string, files return m.Generate(ctx, prompt) } +func (m *DemoModel) GenerateStream(ctx context.Context, prompt string) (<-chan models.StreamChunk, error) { + ch := make(chan models.StreamChunk, 1) + val, err := m.Generate(ctx, prompt) + if err != nil { + ch <- models.StreamChunk{Err: err, Done: true} + } else { + str := fmt.Sprint(val) + ch <- models.StreamChunk{Delta: str, FullText: str, Done: true} + } + close(ch) + return ch, nil +} + func main() { ctx := context.Background() diff --git a/cmd/example/codemode_utcp_workflow/main.go b/cmd/example/codemode_utcp_workflow/main.go index 53409c7..d94728f 100644 --- a/cmd/example/codemode_utcp_workflow/main.go +++ b/cmd/example/codemode_utcp_workflow/main.go @@ -27,6 +27,19 @@ func (m *SimpleModel) GenerateWithFiles(ctx context.Context, prompt string, file return m.Generate(ctx, prompt) } +func (m *SimpleModel) GenerateStream(ctx context.Context, prompt string) (<-chan models.StreamChunk, error) { + ch := make(chan models.StreamChunk, 1) + val, err := m.Generate(ctx, prompt) + if err != nil { + ch <- models.StreamChunk{Err: err, Done: true} + } else { + str := fmt.Sprint(val) + ch <- models.StreamChunk{Delta: str, FullText: str, Done: true} + } + close(ch) + return ch, nil +} + func main() { ctx := context.Background() diff --git a/cmd/example/composability/main.go b/cmd/example/composability/main.go index dc579e7..d35e2f2 100644 --- a/cmd/example/composability/main.go +++ b/cmd/example/composability/main.go @@ -32,6 +32,19 @@ func (m *DummyModel) GenerateWithFiles(ctx context.Context, prompt string, files return m.Generate(ctx, prompt) } +func (m *DummyModel) GenerateStream(ctx context.Context, prompt string) (<-chan models.StreamChunk, error) { + ch := make(chan models.StreamChunk, 1) + val, err := m.Generate(ctx, prompt) + if err != nil { + ch <- models.StreamChunk{Err: err, Done: true} + } else { + str := fmt.Sprint(val) + ch <- models.StreamChunk{Delta: 
str, FullText: str, Done: true} + } + close(ch) + return ch, nil +} + func main() { ctx := context.Background() diff --git a/src/adk/modules/helpers_test.go b/src/adk/modules/helpers_test.go index f36bab7..1f65510 100644 --- a/src/adk/modules/helpers_test.go +++ b/src/adk/modules/helpers_test.go @@ -27,6 +27,14 @@ func (s *stubAgent) GenerateWithFiles(context.Context, string, []models.File) (a return "files", nil } +func (s *stubAgent) GenerateStream(context.Context, string) (<-chan models.StreamChunk, error) { + atomic.AddInt32(&s.called, 1) + ch := make(chan models.StreamChunk, 1) + ch <- models.StreamChunk{Delta: "ok", FullText: "ok", Done: true} + close(ch) + return ch, nil +} + type stubTool struct{ name string } func (s stubTool) Spec() agent.ToolSpec { return agent.ToolSpec{Name: s.name} } diff --git a/src/models/anthropics.go b/src/models/anthropics.go index 08cdefe..9b360e8 100644 --- a/src/models/anthropics.go +++ b/src/models/anthropics.go @@ -122,6 +122,46 @@ func (a *AnthropicLLM) GenerateWithFiles(ctx context.Context, prompt string, fil return b.String(), nil } +// GenerateStream uses Anthropic's streaming messages API. 
+func (a *AnthropicLLM) GenerateStream(ctx context.Context, prompt string) (<-chan StreamChunk, error) { + fullPrompt := prompt + if a.PromptPrefix != "" { + fullPrompt = fmt.Sprintf("%s\n\n%s", a.PromptPrefix, prompt) + } + + stream := a.Client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{ + Model: anthropic.Model(a.Model), + MaxTokens: int64(a.MaxTokens), + Messages: []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock(fullPrompt)), + }, + }) + + ch := make(chan StreamChunk, 16) + go func() { + defer close(ch) + var sb strings.Builder + for stream.Next() { + evt := stream.Current() + switch delta := evt.AsAny().(type) { + case anthropic.ContentBlockDeltaEvent: + text := delta.Delta.Text + if text != "" { + sb.WriteString(text) + ch <- StreamChunk{Delta: text} + } + } + } + if err := stream.Err(); err != nil { + ch <- StreamChunk{Done: true, FullText: sb.String(), Err: err} + return + } + ch <- StreamChunk{Done: true, FullText: sb.String()} + }() + + return ch, nil +} + // sanitizeForAnthropic filters MIME types to what Anthropic supports func sanitizeForAnthropic(mt string) string { mt = strings.ToLower(strings.TrimSpace(mt)) diff --git a/src/models/cached.go b/src/models/cached.go index 9dc5c0c..72ebcda 100644 --- a/src/models/cached.go +++ b/src/models/cached.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "fmt" "os" "strconv" "time" @@ -110,6 +111,43 @@ func (c *CachedLLM) GenerateWithFiles(ctx context.Context, prompt string, files return res, nil } +// GenerateStream passes through to the underlying agent's streaming. +// If the prompt is already cached, it returns a single-chunk stream from cache. +// Otherwise, it streams from the underlying agent and caches the full result when done. 
+func (c *CachedLLM) GenerateStream(ctx context.Context, prompt string) (<-chan StreamChunk, error) { + key := cache.HashKey(prompt) + if val, ok := c.Cache.Get(key); ok { + ch := make(chan StreamChunk, 1) + go func() { + defer close(ch) + text := fmt.Sprint(val) + ch <- StreamChunk{Delta: text, Done: true, FullText: text} + }() + return ch, nil + } + + innerCh, err := c.Agent.GenerateStream(ctx, prompt) + if err != nil { + return nil, err + } + + ch := make(chan StreamChunk, 16) + go func() { + defer close(ch) + for chunk := range innerCh { + ch <- chunk + if chunk.Done { + if chunk.FullText != "" && chunk.Err == nil { + c.Cache.Set(key, chunk.FullText) + c.save() + } + } + } + }() + + return ch, nil +} + // TryCreateCachedLLM checks env vars and wraps the agent if caching is enabled. func TryCreateCachedLLM(agent Agent) Agent { sizeStr := os.Getenv("AGENT_LLM_CACHE_SIZE") diff --git a/src/models/cached_test.go b/src/models/cached_test.go index 00ae5f6..8db9666 100644 --- a/src/models/cached_test.go +++ b/src/models/cached_test.go @@ -21,6 +21,14 @@ func (m *MockAgent) GenerateWithFiles(ctx context.Context, prompt string, files return "mock response with files", nil } +func (m *MockAgent) GenerateStream(ctx context.Context, prompt string) (<-chan StreamChunk, error) { + atomic.AddInt32(&m.CallCount, 1) + ch := make(chan StreamChunk, 1) + ch <- StreamChunk{Delta: "mock stream response", FullText: "mock stream response", Done: true} + close(ch) + return ch, nil +} + func TestCachedLLM_Generate(t *testing.T) { mock := &MockAgent{} cached := NewCachedLLM(mock, 10, time.Minute, "") diff --git a/src/models/dummy.go b/src/models/dummy.go index 391ae08..2b8ecfb 100644 --- a/src/models/dummy.go +++ b/src/models/dummy.go @@ -41,4 +41,27 @@ func (d *DummyLLM) GenerateWithFiles(ctx context.Context, prompt string, files [ return fmt.Sprintf("%s %s", d.Prefix, combined), nil } +// GenerateStream simulates streaming by splitting the response into word-level chunks. 
+func (d *DummyLLM) GenerateStream(ctx context.Context, prompt string) (<-chan StreamChunk, error) {
+	result, _ := d.Generate(ctx, prompt) // error intentionally dropped — TODO: surface it via StreamChunk.Err
+	text := fmt.Sprint(result)
+
+	ch := make(chan StreamChunk, 16)
+	go func() {
+		defer close(ch)
+		words := strings.Fields(text)
+		var sb strings.Builder
+		for i, word := range words {
+			if i > 0 { // re-insert the separator so the aggregate matches the original text
+				word = " " + word
+			}
+			sb.WriteString(word)
+			ch <- StreamChunk{Delta: word}
+		}
+		ch <- StreamChunk{Done: true, FullText: sb.String()}
+	}()
+
+	return ch, nil
+}
+
 var _ Agent = (*DummyLLM)(nil)
diff --git a/src/models/gemini.go b/src/models/gemini.go
index d77cc90..ae11c88 100644
--- a/src/models/gemini.go
+++ b/src/models/gemini.go
@@ -48,6 +48,48 @@ func (g *GeminiLLM) Generate(ctx context.Context, prompt string) (any, error) {
 	return resp.Candidates[0].Content.Parts[0], nil
 }
 
+// GenerateStream uses Gemini's streaming API to yield tokens incrementally.
+func (g *GeminiLLM) GenerateStream(ctx context.Context, prompt string) (<-chan StreamChunk, error) {
+	model := g.Client.GenerativeModel(g.Model)
+	full := prompt
+	if g.PromptPrefix != "" {
+		full = g.PromptPrefix + "\n\n" + prompt
+	}
+
+	iter := model.GenerateContentStream(ctx, genai.Text(full))
+
+	ch := make(chan StreamChunk, 16)
+	go func() {
+		defer close(ch)
+		var sb strings.Builder
+		for {
+			resp, err := iter.Next()
+			if err != nil {
+				// normal end of stream: the Gemini iterator is exhausted, or the context was cancelled
+				if err.Error() == "no more items in iterator" || err == context.Canceled { // TODO: use errors.Is(err, iterator.Done) instead of matching its message text
+					ch <- StreamChunk{Done: true, FullText: sb.String()}
+					return
+				}
+				// any other error terminates the stream; report it on the final chunk
+				ch <- StreamChunk{Done: true, FullText: sb.String(), Err: err}
+				return
+			}
+			if resp == nil || len(resp.Candidates) == 0 {
+				continue
+			}
+			cand := resp.Candidates[0]
+			if cand.Content == nil || len(cand.Content.Parts) == 0 {
+				continue
+			}
+			delta := fmt.Sprint(cand.Content.Parts[0])
+			sb.WriteString(delta)
+			ch <- StreamChunk{Delta: delta}
+		}
+	}()
+
+	return ch, nil
+}
+
 // NEW: pass images/videos as parts so Gemini can read them.
 // Falls back to text-only if there are no binary attachments.
 // gemini.go (inside package models)
diff --git a/src/models/interface.go b/src/models/interface.go
index 57f07a3..24a6bc8 100644
--- a/src/models/interface.go
+++ b/src/models/interface.go
@@ -10,7 +10,23 @@ type File struct {
 	Data []byte
 }
 
+// StreamChunk represents a single piece of a streaming LLM response.
+// When Done is true, the stream is complete and FullText holds the aggregated output.
+// When Err is non-nil, the stream encountered an error.
+type StreamChunk struct {
+	Delta    string // incremental text token
+	Done     bool   // true on the final chunk
+	FullText string // aggregated text (populated only on the final chunk)
+	Err      error  // non-nil if the stream encountered a fatal error
+}
+
 type Agent interface {
 	Generate(context.Context, string) (any, error)
 	GenerateWithFiles(context.Context, string, []File) (any, error)
+
+	// GenerateStream returns a channel that yields incremental text chunks.
+	// The final chunk has Done=true and FullText set to the complete response.
+	// If the provider doesn't support streaming natively, it falls back to
+	// a single-chunk response wrapping Generate.
+	GenerateStream(ctx context.Context, prompt string) (<-chan StreamChunk, error)
 }
diff --git a/src/models/ollama.go b/src/models/ollama.go
index 4be57ca..925e7e6 100644
--- a/src/models/ollama.go
+++ b/src/models/ollama.go
@@ -146,6 +146,39 @@ func (o *OllamaLLM) GenerateWithFiles(ctx context.Context, prompt string, files
 	}, nil
 }
 
+// GenerateStream leverages Ollama's native callback-based streaming.
+func (o *OllamaLLM) GenerateStream(ctx context.Context, prompt string) (<-chan StreamChunk, error) {
+	fullPrompt := prompt
+	if o.PromptPrefix != "" {
+		fullPrompt = fmt.Sprintf("%s\n\n%s", o.PromptPrefix, prompt)
+	}
+
+	req := &ollama.GenerateRequest{
+		Model:  o.Model,
+		Prompt: fullPrompt,
+	}
+
+	ch := make(chan StreamChunk, 16)
+	go func() {
+		defer close(ch)
+		var sb strings.Builder
+		err := o.Client.Generate(ctx, req, func(gr ollama.GenerateResponse) error { // callback runs once per response fragment
+			if gr.Response != "" {
+				sb.WriteString(gr.Response)
+				ch <- StreamChunk{Delta: gr.Response}
+			}
+			return nil // returning non-nil here would abort the Ollama stream
+		})
+		if err != nil { // covers transport errors and ctx cancellation surfaced by Generate
+			ch <- StreamChunk{Done: true, FullText: sb.String(), Err: err}
+			return
+		}
+		ch <- StreamChunk{Done: true, FullText: sb.String()}
+	}()
+
+	return ch, nil
+}
+
 // WebSearch queries the Ollama Web Search API and returns top results.
 func (o *OllamaLLM) WebSearch(ctx context.Context, query string, limit int) ([]map[string]string, error) {
 	endpoint := fmt.Sprintf("%s/api/web_search", strings.TrimRight(o.host, "/"))
diff --git a/src/models/openai.go b/src/models/openai.go
index a3da627..8f52bc3 100644
--- a/src/models/openai.go
+++ b/src/models/openai.go
@@ -5,6 +5,7 @@ import (
 	"encoding/base64"
 	"errors"
 	"fmt"
+	"io"
 	"os"
 	"strings"
 
@@ -48,6 +49,53 @@ func (o *OpenAILLM) Generate(ctx context.Context, prompt string) (any, error) {
 	return resp.Choices[0].Message.Content, nil
 }
 
+// GenerateStream uses OpenAI's streaming chat completion API.
+func (o *OpenAILLM) GenerateStream(ctx context.Context, prompt string) (<-chan StreamChunk, error) {
+	fullPrompt := prompt
+	if o.PromptPrefix != "" {
+		fullPrompt = o.PromptPrefix + "\n" + prompt // NOTE(review): other providers join with "\n\n" — confirm intended
+	}
+
+	stream, err := o.Client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
+		Model: o.Model,
+		Messages: []openai.ChatCompletionMessage{{
+			Role:    openai.ChatMessageRoleUser,
+			Content: fullPrompt,
+		}},
+		Stream: true,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	ch := make(chan StreamChunk, 16)
+	go func() {
+		defer close(ch)
+		defer stream.Close() // release the underlying HTTP connection when the goroutine exits
+		var sb strings.Builder
+		for {
+			resp, err := stream.Recv()
+			if err != nil {
+				if errors.Is(err, io.EOF) { // io.EOF is the normal end-of-stream signal from Recv
+					ch <- StreamChunk{Done: true, FullText: sb.String()}
+					return
+				}
+				ch <- StreamChunk{Done: true, FullText: sb.String(), Err: err}
+				return
+			}
+			if len(resp.Choices) > 0 {
+				delta := resp.Choices[0].Delta.Content
+				if delta != "" {
+					sb.WriteString(delta)
+					ch <- StreamChunk{Delta: delta}
+				}
+			}
+		}
+	}()
+
+	return ch, nil
+}
+
 // getOpenAIMimeType converts normalized MIME types to OpenAI's expected format
 func getOpenAIMimeType(mt string) string {
 	mt = strings.ToLower(strings.TrimSpace(mt))
diff --git a/src/subagents/researcher_test.go b/src/subagents/researcher_test.go
index ee9a081..f846707 100644
--- a/src/subagents/researcher_test.go
+++ b/src/subagents/researcher_test.go
@@ -29,6 +29,13 @@ func (f *fakeModel) Generate(ctx context.Context, prompt string) (any, error) {
 	return "ok", nil
 }
 
+func (f *fakeModel) GenerateStream(ctx context.Context, prompt string) (<-chan models.StreamChunk, error) {
+	ch := make(chan models.StreamChunk, 1)
+	ch <- models.StreamChunk{Delta: "ok", FullText: "ok", Done: true}
+	close(ch)
+	return ch, nil
+}
+
 func TestResearcherRunIncludesPersonaAndTask(t *testing.T) {
 	fm := &fakeModel{response: "result"}
 	researcher := NewResearcher(fm)