diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schemas/chat_stream.go
index 983d2242..ed43b99d 100644
--- a/pkg/api/schemas/chat_stream.go
+++ b/pkg/api/schemas/chat_stream.go
@@ -24,6 +24,23 @@ var (
 	UnknownError ErrorCode = "unknown_error"
 )
 
+// StreamingCacheEntry indexes a cached streaming response by query.
+// ResponseChunks lists the cache keys of the response chunks in order.
+type StreamingCacheEntry struct {
+	Key            string
+	Query          string
+	ResponseChunks []string
+	Complete       bool
+}
+
+// StreamingCacheEntryChunk holds a single cached chunk of a streamed response.
+type StreamingCacheEntryChunk struct {
+	Key      string
+	Index    int
+	Content  ChatStreamChunk
+	Complete bool
+}
+
 type StreamRequestID = string
 
 // ChatStreamRequest defines a message that requests a new streaming chat
diff --git a/pkg/cache/memory_cache.go b/pkg/cache/memory_cache.go
new file mode 100644
index 00000000..3d2045ed
--- /dev/null
+++ b/pkg/cache/memory_cache.go
@@ -0,0 +1,32 @@
+package cache
+
+import "sync"
+
+// MemoryCache is a simple thread-safe, in-memory key-value store.
+// It is unbounded: entries stay cached until the process exits.
+type MemoryCache struct {
+	cache map[string]interface{}
+	lock  sync.RWMutex
+}
+
+// NewMemoryCache creates an empty MemoryCache.
+func NewMemoryCache() *MemoryCache {
+	return &MemoryCache{
+		cache: make(map[string]interface{}),
+	}
+}
+
+// Get returns the value stored under key and whether it was found.
+func (m *MemoryCache) Get(key string) (interface{}, bool) {
+	m.lock.RLock()
+	defer m.lock.RUnlock()
+	val, found := m.cache[key]
+	return val, found
+}
+
+// Set stores value under key, overwriting any previous value.
+func (m *MemoryCache) Set(key string, value interface{}) {
+	m.lock.Lock()
+	defer m.lock.Unlock()
+	m.cache[key] = value
+}
diff --git a/pkg/routers/router.go b/pkg/routers/router.go
index a4128f7d..00ed4f90 100644
--- a/pkg/routers/router.go
+++ b/pkg/routers/router.go
@@ -3,7 +3,9 @@ package routers
 import (
 	"context"
 	"errors"
+	"fmt"
 
+	"github.com/EinStack/glide/pkg/cache"
 	"github.com/EinStack/glide/pkg/routers/retry"
 
 	"go.uber.org/zap"
@@ -33,6 +35,7 @@ type LangRouter struct {
 	retry             *retry.ExpRetry
 	tel               *telemetry.Telemetry
 	logger            *zap.Logger
+	cache             *cache.MemoryCache
 }
 
 func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter, error) {
@@ -56,6 +59,7 @@ func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter
 		chatStreamRouting: chatStreamRouting,
 		tel:               tel,
 		logger:            tel.L().With(zap.String("routerID", cfg.ID)),
+		cache:             cache.NewMemoryCache(),
 	}
 
 	return router, err
@@ -70,6 +74,18 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
 		return nil, ErrNoModels
 	}
 
+	// Naive cache key: the raw message content, so identical queries share an entry
+	cacheKey := req.Message.Content
+
+	if cachedResponse, found := r.cache.Get(cacheKey); found {
+		if response, ok := cachedResponse.(*schemas.ChatResponse); ok {
+			r.logger.Debug("Returning cached chat response")
+			return response, nil
+		}
+
+		r.logger.Warn("Failed to cast cached value to ChatResponse, ignoring cache entry")
+	}
+
 	retryIterator := r.retry.Iterator()
 
 	for retryIterator.HasNext() {
@@ -101,17 +117,18 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
 				zap.String("provider", langModel.Provider()),
 				zap.Error(err),
 			)
 
 			continue
 		}
 
 		resp.RouterID = r.routerID
 
+		// Cache the response so future identical queries can skip provider calls
+		r.cache.Set(cacheKey, resp)
+
 		return resp, nil
 	}
 
-	// no providers were available to handle the request,
-	// so we have to wait a bit with a hope there is some available next time
 	r.logger.Warn("No healthy model found to serve chat request, wait and retry")
 
 	err := retryIterator.WaitNext(ctx)
@@ -121,7 +138,6 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem
 		}
 	}
 
-	// if we reach this part, then we are in trouble
 	r.logger.Error("No model was available to handle chat request")
 
 	return nil, ErrNoModelAvailable
@@ -141,10 +157,46 @@ func (r *LangRouter) ChatStream(
 			req.Metadata,
 			&schemas.ErrorReason,
 		)
-
 		return
 	}
 
+	// Naive cache key: the raw message content, so identical queries share an entry
+	cacheKey := req.Message.Content
+
+	// Serve the stream straight from the cache when a complete entry exists
+	if cached, found := r.cache.Get(cacheKey); found {
+		if entry, ok := cached.(*schemas.StreamingCacheEntry); ok && entry.Complete {
+			for _, chunkKey := range entry.ResponseChunks {
+				cachedChunk, chunkFound := r.cache.Get(chunkKey)
+				if !chunkFound {
+					continue
+				}
+
+				chunk, ok := cachedChunk.(*schemas.StreamingCacheEntryChunk)
+				if !ok {
+					r.logger.Warn("Failed to cast cached chunk to StreamingCacheEntryChunk")
+					continue
+				}
+
+				respC <- schemas.NewChatStreamChunk(
+					req.ID,
+					r.routerID,
+					req.Metadata,
+					&chunk.Content,
+				)
+			}
+
+			return
+		}
+	}
+
+	// Otherwise start a fresh entry, chunk keys are appended to it as chunks stream in
+	streamingCacheEntry := &schemas.StreamingCacheEntry{
+		Key:   cacheKey,
+		Query: req.Message.Content,
+	}
+	r.cache.Set(cacheKey, streamingCacheEntry)
+
 	retryIterator := r.retry.Iterator()
 
 	for retryIterator.HasNext() {
@@ -172,6 +224,11 @@ func (r *LangRouter) ChatStream(
 			continue
 		}
 
+		const bufferSize = 1048 // cache writes are batched in chunks of this size
+
+		buffer := make([]schemas.ChatStreamChunk, 0, bufferSize)
+		chunkIndex := 0
+
 		for chunkResult := range modelRespC {
 			err = chunkResult.Error()
 			if err != nil {
@@ -182,9 +239,6 @@ func (r *LangRouter) ChatStream(
 				zap.String("provider", langModel.Provider()),
 				zap.Error(err),
 			)
-			// It's challenging to hide an error in case of streaming chat as consumer apps
-			// may have already used all chunks we streamed this far (e.g. showed them to their users like OpenAI UI does),
-			// so we cannot easily restart that process from scratch
 			respC <- schemas.NewChatStreamError(
 				req.ID,
 				r.routerID,
@@ -198,23 +252,55 @@ func (r *LangRouter) ChatStream(
 			chunk := chunkResult.Chunk()
-
+			buffer = append(buffer, *chunk)
+
 			respC <- schemas.NewChatStreamChunk(
 				req.ID,
 				r.routerID,
 				req.Metadata,
 				chunk,
 			)
+
+			if len(buffer) >= bufferSize {
+				// Flush the buffered chunks to the cache, one cache entry per chunk
+				for _, buffered := range buffer {
+					chunkKey := fmt.Sprintf("%s-chunk-%d", cacheKey, chunkIndex)
+					r.cache.Set(chunkKey, &schemas.StreamingCacheEntryChunk{
+						Key:     chunkKey,
+						Index:   chunkIndex,
+						Content: buffered,
+					})
+
+					streamingCacheEntry.ResponseChunks = append(streamingCacheEntry.ResponseChunks, chunkKey)
+					chunkIndex++
+				}
+
+				buffer = buffer[:0] // reset the buffer, the chunk keys now live on the entry
+				r.cache.Set(cacheKey, streamingCacheEntry)
+			}
 		}
 
+		// Flush any chunks left in the buffer and mark the entry as complete
+		for _, buffered := range buffer {
+			chunkKey := fmt.Sprintf("%s-chunk-%d", cacheKey, chunkIndex)
+			r.cache.Set(chunkKey, &schemas.StreamingCacheEntryChunk{
+				Key:     chunkKey,
+				Index:   chunkIndex,
+				Content: buffered,
+			})
+
+			streamingCacheEntry.ResponseChunks = append(streamingCacheEntry.ResponseChunks, chunkKey)
+			chunkIndex++
+		}
+
+		streamingCacheEntry.Complete = true
+		r.cache.Set(cacheKey, streamingCacheEntry)
+
 		return
 	}
 
-	// no providers were available to handle the request,
-	// so we have to wait a bit with a hope there is some available next time
 	r.logger.Warn("No healthy model found to serve streaming chat request, wait and retry")
 
 	err := retryIterator.WaitNext(ctx)
 	if err != nil {
-		// something has cancelled the context
 		respC <- schemas.NewChatStreamError(
 			req.ID,
 			r.routerID,
@@ -230,7 +316,6 @@ func (r *LangRouter) ChatStream(
 		}
 	}
 
-	// if we reach this part, then we are in trouble
 	r.logger.Error(
 		"No model was available to handle streaming chat request. " +
 			"Try to configure more fallback models to avoid this",
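
The change adds MemoryCache without tests. As a starting point, a minimal companion test could look like the sketch below; the file name and test name are hypothetical, not part of the original change, and the test assumes only the Get/Set API shown above:

diff --git a/pkg/cache/memory_cache_test.go b/pkg/cache/memory_cache_test.go
new file mode 100644
--- /dev/null
+++ b/pkg/cache/memory_cache_test.go
@@ -0,0 +1,19 @@
+package cache
+
+import "testing"
+
+// TestMemoryCacheSetGet exercises the miss/store/load path of MemoryCache.
+func TestMemoryCacheSetGet(t *testing.T) {
+	c := NewMemoryCache()
+
+	if _, found := c.Get("missing"); found {
+		t.Fatal("expected a miss for an unknown key")
+	}
+
+	c.Set("key", "value")
+
+	val, found := c.Get("key")
+	if !found || val.(string) != "value" {
+		t.Fatalf("expected cached value, got %v (found=%v)", val, found)
+	}
+}

Since Get and Set guard the map with an RWMutex, this test needs no extra synchronization; a fuller suite would also cover concurrent access under `go test -race` and, once the cache gains a size bound or TTL, eviction behavior.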