diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go index 86c162fe..16ea3496 100644 --- a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -103,12 +103,12 @@ type ConversationRepository interface { RemovePendingToolCallByID(toolCallID string) StartNewConversation(title string) error DeleteMessagesAfterIndex(index int) error + GetCurrentConversationTitle() string } // ConversationOptimizerService optimizes conversation history to reduce token usage type ConversationOptimizerService interface { - OptimizeMessages(messages []sdk.Message, force bool) []sdk.Message - OptimizeMessagesWithModel(messages []sdk.Message, currentModel string, force bool) []sdk.Message + OptimizeMessages(messages []sdk.Message, model string, force bool) []sdk.Message } // ModelService handles model selection and information diff --git a/internal/handlers/chat_shortcut_handler.go b/internal/handlers/chat_shortcut_handler.go index 66eb4f3a..7af00654 100644 --- a/internal/handlers/chat_shortcut_handler.go +++ b/internal/handlers/chat_shortcut_handler.go @@ -507,6 +507,8 @@ func (s *ChatShortcutHandler) performCompactAsync() tea.Cmd { logger.Info("Starting conversation compaction", "message_count", len(entries)) + originalTitle := s.handler.conversationRepo.GetCurrentConversationTitle() + messages := make([]sdk.Message, 0, len(entries)) for _, entry := range entries { if entry.Hidden { @@ -517,14 +519,19 @@ func (s *ChatShortcutHandler) performCompactAsync() tea.Cmd { currentModel := s.handler.modelService.GetCurrentModel() if currentModel == "" { - logger.Warn("No current model set for compaction - will use basic summary") + return domain.SetStatusEvent{ + Message: "No model selected - please select a model first", + Spinner: false, + StatusType: domain.StatusError, + } } - logger.Info("About to optimize conversation", "model", currentModel, "message_count", len(messages)) + + logger.Info("Optimizing conversation", "model", currentModel, "message_count", len(messages)) optimizedChan := make(chan []sdk.Message, 1) go func() { - result := s.handler.conversationOptimizer.OptimizeMessagesWithModel(messages, currentModel, true) - optimizedChan <- result + optimized := s.handler.conversationOptimizer.OptimizeMessages(messages, currentModel, true) + optimizedChan <- optimized }() var optimizedMessages []sdk.Message @@ -534,7 +541,7 @@ func (s *ChatShortcutHandler) performCompactAsync() tea.Cmd { case <-time.After(70 * time.Second): logger.Error("Optimization timed out after 70 seconds") return domain.SetStatusEvent{ - Message: "Conversation compaction timed out - try again or check gateway logs", + Message: "Conversation optimization timed out - try again or check gateway logs", Spinner: false, StatusType: domain.StatusError, } @@ -548,10 +555,11 @@ func (s *ChatShortcutHandler) performCompactAsync() tea.Cmd { } } - if clearErr := s.handler.conversationRepo.Clear(); clearErr != nil { - logger.Error("failed to clear conversation during compaction", "error", clearErr) + newTitle := fmt.Sprintf("Continued from %s", originalTitle) + if err := s.handler.conversationRepo.StartNewConversation(newTitle); err != nil { + logger.Error("Failed to start new conversation", "error", err) return domain.SetStatusEvent{ - Message: fmt.Sprintf("Failed to compact conversation: %v", clearErr), + Message: fmt.Sprintf("Failed to start new conversation: %v", err), Spinner: false, StatusType: domain.StatusError, } @@ -560,29 +568,14 @@ func (s *ChatShortcutHandler) performCompactAsync() tea.Cmd { for _, msg := range optimizedMessages { entry := domain.ConversationEntry{ Message: msg, + Model: currentModel, Time: time.Now(), } - if addErr := s.handler.conversationRepo.AddMessage(entry); addErr != nil { - logger.Error("failed to add optimized message during compaction", "error", addErr) + if err := s.handler.conversationRepo.AddMessage(entry); err != nil { + logger.Error("Failed to add optimized message", "error", err) } } - reduction := len(messages) - len(optimizedMessages) - reductionPercent := (float64(reduction) / float64(len(messages))) * 100 - - infoEntry := domain.ConversationEntry{ - Message: sdk.Message{ - Role: sdk.Assistant, - Content: sdk.NewMessageContent(fmt.Sprintf("Conversation compacted successfully! Reduced from %d to %d messages (%.1f%% reduction).", len(messages), len(optimizedMessages), reductionPercent)), - }, - Model: "", - Time: time.Now(), - } - - if addErr := s.handler.conversationRepo.AddMessage(infoEntry); addErr != nil { - logger.Error("failed to add compact info message", "error", addErr) - } - return tea.Batch( func() tea.Msg { return domain.UpdateHistoryEvent{ @@ -591,7 +584,7 @@ func (s *ChatShortcutHandler) performCompactAsync() tea.Cmd { }, func() tea.Msg { return domain.SetStatusEvent{ - Message: fmt.Sprintf("Conversation compacted: %d messages reduced to %d", len(messages), len(optimizedMessages)), + Message: fmt.Sprintf("• Started new conversation with summary (%d messages preserved)", len(messages)), Spinner: false, StatusType: domain.StatusDefault, } diff --git a/internal/infra/adapters/persistent_conversation_adapter.go b/internal/infra/adapters/persistent_conversation_adapter.go index a48f3c71..c3b74da3 100644 --- a/internal/infra/adapters/persistent_conversation_adapter.go +++ b/internal/infra/adapters/persistent_conversation_adapter.go @@ -71,7 +71,6 @@ func (a *PersistentConversationAdapter) GetCurrentConversationMetadata() shortcu CostStats: metadata.CostStats, Model: metadata.Model, Tags: metadata.Tags, - Summary: metadata.Summary, } } diff --git a/internal/infra/storage/interfaces.go b/internal/infra/storage/interfaces.go index 310d97eb..5a836cca 100644 --- a/internal/infra/storage/interfaces.go +++ b/internal/infra/storage/interfaces.go @@ -36,21 +36,19 @@ type ConversationStorage interface { // ConversationMetadata contains metadata about a conversation type ConversationMetadata struct { - ID string `json:"id"` - Title string `json:"title"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - MessageCount int `json:"message_count"` - TokenStats domain.SessionTokenStats `json:"token_stats"` - CostStats domain.SessionCostStats `json:"cost_stats,omitempty"` - Model string `json:"model,omitempty"` - Tags []string `json:"tags,omitempty"` - Summary string `json:"summary,omitempty"` - OptimizedMessages []domain.ConversationEntry `json:"optimized_messages,omitempty"` - TitleGenerated bool `json:"title_generated,omitempty"` - TitleInvalidated bool `json:"title_invalidated,omitempty"` - TitleGenerationTime *time.Time `json:"title_generation_time,omitempty"` - ContextID string `json:"context_id,omitempty"` + ID string `json:"id"` + Title string `json:"title"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + MessageCount int `json:"message_count"` + TokenStats domain.SessionTokenStats `json:"token_stats"` + CostStats domain.SessionCostStats `json:"cost_stats,omitempty"` + Model string `json:"model,omitempty"` + Tags []string `json:"tags,omitempty"` + TitleGenerated bool `json:"title_generated,omitempty"` + TitleInvalidated bool `json:"title_invalidated,omitempty"` + TitleGenerationTime *time.Time `json:"title_generation_time,omitempty"` + ContextID string `json:"context_id,omitempty"` } // ConversationSummary contains summary information about a conversation diff --git a/internal/infra/storage/jsonl.go b/internal/infra/storage/jsonl.go index 9e94c4c2..1af3f918 100644 --- a/internal/infra/storage/jsonl.go +++ b/internal/infra/storage/jsonl.go @@ -208,7 +208,6 @@ func (s *JsonlStorage) ListConversations(ctx context.Context, limit, offset int) CostStats: metadataWrapper.Metadata.CostStats, Model: metadataWrapper.Metadata.Model, Tags: metadataWrapper.Metadata.Tags, - Summary: metadataWrapper.Metadata.Summary, TitleGenerated: metadataWrapper.Metadata.TitleGenerated, TitleInvalidated: metadataWrapper.Metadata.TitleInvalidated, TitleGenerationTime: metadataWrapper.Metadata.TitleGenerationTime, diff --git a/internal/infra/storage/memory.go b/internal/infra/storage/memory.go index 2b744f0d..fa8b390c 100644 --- a/internal/infra/storage/memory.go +++ b/internal/infra/storage/memory.go @@ -81,7 +81,6 @@ func (m *MemoryStorage) ListConversations(ctx context.Context, limit, offset int TokenStats: data.metadata.TokenStats, Model: data.metadata.Model, Tags: data.metadata.Tags, - Summary: data.metadata.Summary, TitleGenerated: data.metadata.TitleGenerated, TitleInvalidated: data.metadata.TitleInvalidated, TitleGenerationTime: data.metadata.TitleGenerationTime, diff --git a/internal/infra/storage/migrations/postgres_migrations.go b/internal/infra/storage/migrations/postgres_migrations.go index ee8081d6..08d50902 100644 --- a/internal/infra/storage/migrations/postgres_migrations.go +++ b/internal/infra/storage/migrations/postgres_migrations.go @@ -15,8 +15,6 @@ func GetPostgresMigrations() []Migration { message_count INTEGER NOT NULL DEFAULT 0, model VARCHAR(255), tags JSONB, - summary TEXT, - optimized_messages JSONB, token_stats JSONB, cost_stats JSONB, title_generated BOOLEAN DEFAULT FALSE, diff --git a/internal/infra/storage/migrations/sqlite_migrations.go b/internal/infra/storage/migrations/sqlite_migrations.go index 7207e0fe..ac8736da 100644 --- a/internal/infra/storage/migrations/sqlite_migrations.go +++ b/internal/infra/storage/migrations/sqlite_migrations.go @@ -12,14 +12,12 @@ func GetSQLiteMigrations() []Migration { title TEXT NOT NULL, count INTEGER NOT NULL DEFAULT 0, messages TEXT NOT NULL, - optimized_messages TEXT, total_input_tokens INTEGER NOT NULL DEFAULT 0, total_output_tokens INTEGER NOT NULL DEFAULT 0, request_count INTEGER NOT NULL DEFAULT 0, cost_stats TEXT DEFAULT '{}', models TEXT DEFAULT '[]', tags TEXT DEFAULT '[]', - summary TEXT DEFAULT '', title_generated BOOLEAN DEFAULT FALSE, title_invalidated BOOLEAN DEFAULT FALSE, title_generation_time DATETIME, diff --git a/internal/infra/storage/postgres.go b/internal/infra/storage/postgres.go index a3f8ae6c..96304659 100644 --- a/internal/infra/storage/postgres.go +++ b/internal/infra/storage/postgres.go @@ -130,31 +130,21 @@ func (s *PostgresStorage) SaveConversation(ctx context.Context, conversationID s return fmt.Errorf("failed to marshal cost stats: %w", err) } - var optimizedMessagesJSON []byte - if len(metadata.OptimizedMessages) > 0 { - optimizedMessagesJSON, err = json.Marshal(metadata.OptimizedMessages) - if err != nil { - return fmt.Errorf("failed to marshal optimized messages: %w", err) - } - } - _, err = tx.ExecContext(ctx, ` - INSERT INTO conversations (id, title, created_at, updated_at, message_count, model, tags, summary, optimized_messages, token_stats, cost_stats, title_generated, title_invalidated, title_generation_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + INSERT INTO conversations (id, title, created_at, updated_at, message_count, model, tags, token_stats, cost_stats, title_generated, title_invalidated, title_generation_time) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ON CONFLICT(id) DO UPDATE SET title = EXCLUDED.title, updated_at = EXCLUDED.updated_at, message_count = EXCLUDED.message_count, model = EXCLUDED.model, tags = EXCLUDED.tags, - summary = EXCLUDED.summary, - optimized_messages = EXCLUDED.optimized_messages, token_stats = EXCLUDED.token_stats, cost_stats = EXCLUDED.cost_stats, title_generated = EXCLUDED.title_generated, title_invalidated = EXCLUDED.title_invalidated, title_generation_time = EXCLUDED.title_generation_time - `, conversationID, metadata.Title, metadata.CreatedAt, metadata.UpdatedAt, len(entries), metadata.Model, string(tagsJSON), metadata.Summary, optimizedMessagesJSON, string(tokenStatsJSON), string(costStatsJSON), metadata.TitleGenerated, metadata.TitleInvalidated, metadata.TitleGenerationTime) + `, conversationID, metadata.Title, metadata.CreatedAt, metadata.UpdatedAt, len(entries), metadata.Model, string(tagsJSON), string(tokenStatsJSON), string(costStatsJSON), metadata.TitleGenerated, metadata.TitleInvalidated, metadata.TitleGenerationTime) if err != nil { return fmt.Errorf("failed to save conversation metadata: %w", err) } @@ -186,15 +176,14 @@ func (s *PostgresStorage) SaveConversation(ctx context.Context, conversationID s func (s *PostgresStorage) LoadConversation(ctx context.Context, conversationID string) ([]domain.ConversationEntry, ConversationMetadata, error) { var metadata ConversationMetadata var tokenStatsJSON, tagsJSON, costStatsJSON string - var optimizedMessagesJSON sql.NullString err := s.db.QueryRowContext(ctx, ` - SELECT id, title, created_at, updated_at, message_count, model, tags, summary, optimized_messages, token_stats, COALESCE(cost_stats, '{}'), + SELECT id, title, created_at, updated_at, message_count, model, tags, token_stats, COALESCE(cost_stats, '{}'), COALESCE(title_generated, FALSE), COALESCE(title_invalidated, FALSE), title_generation_time FROM conversations WHERE id = $1 `, conversationID).Scan( &metadata.ID, &metadata.Title, &metadata.CreatedAt, &metadata.UpdatedAt, - &metadata.MessageCount, &metadata.Model, &tagsJSON, &metadata.Summary, &optimizedMessagesJSON, &tokenStatsJSON, &costStatsJSON, + &metadata.MessageCount, &metadata.Model, &tagsJSON, &tokenStatsJSON, &costStatsJSON, &metadata.TitleGenerated, &metadata.TitleInvalidated, &metadata.TitleGenerationTime, ) if err != nil { @@ -218,12 +207,6 @@ func (s *PostgresStorage) LoadConversation(ctx context.Context, conversationID s return nil, metadata, fmt.Errorf("failed to unmarshal tags: %w", err) } - if optimizedMessagesJSON.Valid && optimizedMessagesJSON.String != "" { - if err := json.Unmarshal([]byte(optimizedMessagesJSON.String), &metadata.OptimizedMessages); err != nil { - return nil, metadata, fmt.Errorf("failed to unmarshal optimized messages: %w", err) - } - } - rows, err := s.db.QueryContext(ctx, ` SELECT entry_data FROM conversation_entries WHERE conversation_id = $1 @@ -273,7 +256,7 @@ func (s *PostgresStorage) ListConversations(ctx context.Context, limit, offset i err := rows.Scan( &summary.ID, &summary.Title, &summary.CreatedAt, &summary.UpdatedAt, - &summary.MessageCount, &summary.Model, &tagsJSON, &summary.Summary, &tokenStatsJSON, &costStatsJSON, + &summary.MessageCount, &summary.Model, &tagsJSON, &tokenStatsJSON, &costStatsJSON, &summary.TitleGenerated, &summary.TitleInvalidated, &summary.TitleGenerationTime, ) if err != nil { @@ -323,7 +306,7 @@ func (s *PostgresStorage) ListConversationsNeedingTitles(ctx context.Context, li err := rows.Scan( &summary.ID, &summary.Title, &summary.CreatedAt, &summary.UpdatedAt, - &summary.MessageCount, &summary.Model, &tagsJSON, &summary.Summary, &tokenStatsJSON, &costStatsJSON, + &summary.MessageCount, &summary.Model, &tagsJSON, &tokenStatsJSON, &costStatsJSON, &summary.TitleGenerated, &summary.TitleInvalidated, &summary.TitleGenerationTime, ) if err != nil { @@ -388,10 +371,10 @@ func (s *PostgresStorage) UpdateConversationMetadata(ctx context.Context, conver result, err := s.db.ExecContext(ctx, ` UPDATE conversations - SET title = $1, updated_at = $2, model = $3, tags = $4, summary = $5, token_stats = $6, cost_stats = $7, - title_generated = $8, title_invalidated = $9, title_generation_time = $10 - WHERE id = $11 - `, metadata.Title, metadata.UpdatedAt, metadata.Model, string(tagsJSON), metadata.Summary, string(tokenStatsJSON), string(costStatsJSON), metadata.TitleGenerated, metadata.TitleInvalidated, metadata.TitleGenerationTime, conversationID) + SET title = $1, updated_at = $2, model = $3, tags = $4, token_stats = $5, cost_stats = $6, + title_generated = $7, title_invalidated = $8, title_generation_time = $9 + WHERE id = $10 + `, metadata.Title, metadata.UpdatedAt, metadata.Model, string(tagsJSON), string(tokenStatsJSON), string(costStatsJSON), metadata.TitleGenerated, metadata.TitleInvalidated, metadata.TitleGenerationTime, conversationID) if err != nil { return fmt.Errorf("failed to update conversation metadata: %w", err) } diff --git a/internal/infra/storage/redis.go b/internal/infra/storage/redis.go index 7163979b..6b3ca359 100644 --- a/internal/infra/storage/redis.go +++ b/internal/infra/storage/redis.go @@ -221,15 +221,15 @@ func (s *RedisStorage) ListConversations(ctx context.Context, limit, offset int) } summary := ConversationSummary{ - ID: metadata.ID, - Title: metadata.Title, - CreatedAt: metadata.CreatedAt, - UpdatedAt: metadata.UpdatedAt, - MessageCount: metadata.MessageCount, - TokenStats: metadata.TokenStats, - Model: metadata.Model, - Tags: metadata.Tags, - Summary: metadata.Summary, + ID: metadata.ID, + Title: metadata.Title, + CreatedAt: metadata.CreatedAt, + UpdatedAt: metadata.UpdatedAt, + MessageCount: metadata.MessageCount, + TokenStats: metadata.TokenStats, + Model: metadata.Model, + Tags: metadata.Tags, + TitleGenerated: metadata.TitleGenerated, TitleInvalidated: metadata.TitleInvalidated, TitleGenerationTime: metadata.TitleGenerationTime, @@ -288,15 +288,15 @@ func (s *RedisStorage) ListConversationsNeedingTitles(ctx context.Context, limit if (!metadata.TitleGenerated || metadata.TitleInvalidated) && metadata.MessageCount > 0 { summary := ConversationSummary{ - ID: metadata.ID, - Title: metadata.Title, - CreatedAt: metadata.CreatedAt, - UpdatedAt: metadata.UpdatedAt, - MessageCount: metadata.MessageCount, - TokenStats: metadata.TokenStats, - Model: metadata.Model, - Tags: metadata.Tags, - Summary: metadata.Summary, + ID: metadata.ID, + Title: metadata.Title, + CreatedAt: metadata.CreatedAt, + UpdatedAt: metadata.UpdatedAt, + MessageCount: metadata.MessageCount, + TokenStats: metadata.TokenStats, + Model: metadata.Model, + Tags: metadata.Tags, + TitleGenerated: metadata.TitleGenerated, TitleInvalidated: metadata.TitleInvalidated, TitleGenerationTime: metadata.TitleGenerationTime, diff --git a/internal/infra/storage/sqlite.go b/internal/infra/storage/sqlite.go index 90e8c002..c875f050 100644 --- a/internal/infra/storage/sqlite.go +++ b/internal/infra/storage/sqlite.go @@ -121,14 +121,6 @@ func (s *SQLiteStorage) SaveConversation(ctx context.Context, conversationID str return fmt.Errorf("failed to marshal tags: %w", err) } - var optimizedMessagesJSON []byte - if len(metadata.OptimizedMessages) > 0 { - optimizedMessagesJSON, err = json.Marshal(metadata.OptimizedMessages) - if err != nil { - return fmt.Errorf("failed to marshal optimized messages: %w", err) - } - } - totalInputTokens := metadata.TokenStats.TotalInputTokens totalOutputTokens := metadata.TokenStats.TotalOutputTokens requestCount := metadata.TokenStats.RequestCount @@ -138,35 +130,27 @@ func (s *SQLiteStorage) SaveConversation(ctx context.Context, conversationID str return fmt.Errorf("failed to marshal cost stats: %w", err) } - var optimizedMessagesStr *string - if len(optimizedMessagesJSON) > 0 { - str := string(optimizedMessagesJSON) - optimizedMessagesStr = &str - } - _, err = s.db.ExecContext(ctx, ` - INSERT INTO conversations (id, title, count, messages, optimized_messages, total_input_tokens, total_output_tokens, - request_count, cost_stats, models, tags, summary, title_generated, title_invalidated, title_generation_time, + INSERT INTO conversations (id, title, count, messages, total_input_tokens, total_output_tokens, + request_count, cost_stats, models, tags, title_generated, title_invalidated, title_generation_time, created_at, updated_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET title = excluded.title, count = excluded.count, messages = excluded.messages, - optimized_messages = excluded.optimized_messages, total_input_tokens = excluded.total_input_tokens, total_output_tokens = excluded.total_output_tokens, request_count = excluded.request_count, cost_stats = excluded.cost_stats, models = excluded.models, tags = excluded.tags, - summary = excluded.summary, title_generated = excluded.title_generated, title_invalidated = excluded.title_invalidated, title_generation_time = excluded.title_generation_time, updated_at = excluded.updated_at - `, conversationID, metadata.Title, len(entries), string(messagesJSON), optimizedMessagesStr, totalInputTokens, totalOutputTokens, - requestCount, string(costStatsJSON), string(modelsJSON), string(tagsJSON), metadata.Summary, metadata.TitleGenerated, metadata.TitleInvalidated, + `, conversationID, metadata.Title, len(entries), string(messagesJSON), totalInputTokens, totalOutputTokens, + requestCount, string(costStatsJSON), string(modelsJSON), string(tagsJSON), metadata.TitleGenerated, metadata.TitleInvalidated, metadata.TitleGenerationTime, metadata.CreatedAt.Format(time.RFC3339), metadata.UpdatedAt.Format(time.RFC3339)) if err != nil { return fmt.Errorf("failed to save conversation: %w", err) @@ -177,17 +161,11 @@ func (s *SQLiteStorage) SaveConversation(ctx context.Context, conversationID str // LoadConversation loads a conversation by its ID using simplified schema func (s *SQLiteStorage) LoadConversation(ctx context.Context, conversationID string) ([]domain.ConversationEntry, ConversationMetadata, error) { - metadata, messagesJSON, optimizedMessagesJSON, err := s.loadConversationMetadata(ctx, conversationID) + metadata, messagesJSON, err := s.loadConversationMetadata(ctx, conversationID) if err != nil { return nil, metadata, err } - if optimizedMessagesJSON.Valid && optimizedMessagesJSON.String != "" { - if err := json.Unmarshal([]byte(optimizedMessagesJSON.String), &metadata.OptimizedMessages); err != nil { - return nil, metadata, fmt.Errorf("failed to unmarshal optimized messages: %w", err) - } - } - var entries []domain.ConversationEntry if err := json.Unmarshal([]byte(messagesJSON), &entries); err != nil { return nil, metadata, fmt.Errorf("failed to unmarshal messages: %w", err) @@ -197,30 +175,29 @@ func (s *SQLiteStorage) LoadConversation(ctx context.Context, conversationID str } // loadConversationMetadata loads the metadata for a conversation -func (s *SQLiteStorage) loadConversationMetadata(ctx context.Context, conversationID string) (ConversationMetadata, string, sql.NullString, error) { +func (s *SQLiteStorage) loadConversationMetadata(ctx context.Context, conversationID string) (ConversationMetadata, string, error) { var metadata ConversationMetadata var messagesJSON, modelsJSON, tagsJSON, costStatsJSON string - var optimizedMessagesJSON sql.NullString var totalInputTokens, totalOutputTokens, requestCount int var titleGenerationTime sql.NullTime err := s.db.QueryRowContext(ctx, ` - SELECT id, title, count, messages, optimized_messages, total_input_tokens, total_output_tokens, - request_count, cost_stats, models, tags, summary, title_generated, title_invalidated, title_generation_time, + SELECT id, title, count, messages, total_input_tokens, total_output_tokens, + request_count, cost_stats, models, tags, title_generated, title_invalidated, title_generation_time, created_at, updated_at FROM conversations WHERE id = ? `, conversationID).Scan( &metadata.ID, &metadata.Title, &metadata.MessageCount, - &messagesJSON, &optimizedMessagesJSON, &totalInputTokens, &totalOutputTokens, - &requestCount, &costStatsJSON, &modelsJSON, &tagsJSON, &metadata.Summary, + &messagesJSON, &totalInputTokens, &totalOutputTokens, + &requestCount, &costStatsJSON, &modelsJSON, &tagsJSON, &metadata.TitleGenerated, &metadata.TitleInvalidated, &titleGenerationTime, &metadata.CreatedAt, &metadata.UpdatedAt, ) if err != nil { if err == sql.ErrNoRows { - return metadata, "", optimizedMessagesJSON, fmt.Errorf("conversation not found: %s", conversationID) + return metadata, "", fmt.Errorf("conversation not found: %s", conversationID) } - return metadata, "", optimizedMessagesJSON, fmt.Errorf("failed to load conversation: %w", err) + return metadata, "", fmt.Errorf("failed to load conversation: %w", err) } metadata.TokenStats = domain.SessionTokenStats{ @@ -251,7 +228,7 @@ func (s *SQLiteStorage) loadConversationMetadata(ctx context.Context, conversati metadata.TitleGenerationTime = &titleGenerationTime.Time } - return metadata, messagesJSON, optimizedMessagesJSON, nil + return metadata, messagesJSON, nil } // ListConversations returns a list of conversation summaries @@ -304,7 +281,7 @@ func (s *SQLiteStorage) ListConversations(ctx context.Context, limit, offset int func (s *SQLiteStorage) ListConversationsNeedingTitles(ctx context.Context, limit int) ([]ConversationSummary, error) { rows, err := s.db.QueryContext(ctx, ` SELECT id, title, created_at, updated_at, count, total_input_tokens, total_output_tokens, - models, tags, summary, title_generated, title_invalidated, title_generation_time + models, tags, title_generated, title_invalidated, title_generation_time FROM conversations WHERE (title_generated = FALSE OR title_invalidated = TRUE) AND count >= 2 -- Only conversations with at least 2 messages (user + assistant) @@ -326,7 +303,7 @@ func (s *SQLiteStorage) ListConversationsNeedingTitles(ctx context.Context, limi err := rows.Scan( &summary.ID, &summary.Title, &summary.CreatedAt, &summary.UpdatedAt, &summary.MessageCount, &totalInputTokens, &totalOutputTokens, - &modelsJSON, &tagsJSON, &summary.Summary, + &modelsJSON, &tagsJSON, &summary.TitleGenerated, &summary.TitleInvalidated, &titleGenerationTime, ) if err != nil { @@ -402,11 +379,11 @@ func (s *SQLiteStorage) UpdateConversationMetadata(ctx context.Context, conversa result, err := s.db.ExecContext(ctx, ` UPDATE conversations - SET title = ?, updated_at = ?, models = ?, tags = ?, summary = ?, + SET title = ?, updated_at = ?, models = ?, tags = ?, total_input_tokens = ?, total_output_tokens = ?, request_count = ?, cost_stats = ?, title_generated = ?, title_invalidated = ?, title_generation_time = ? WHERE id = ? - `, metadata.Title, metadata.UpdatedAt, modelsJSON, string(tagsJSON), metadata.Summary, + `, metadata.Title, metadata.UpdatedAt, modelsJSON, string(tagsJSON), metadata.TokenStats.TotalInputTokens, metadata.TokenStats.TotalOutputTokens, metadata.TokenStats.RequestCount, string(costStatsJSON), metadata.TitleGenerated, metadata.TitleInvalidated, metadata.TitleGenerationTime, conversationID) if err != nil { diff --git a/internal/infra/storage/sqlite_test.go b/internal/infra/storage/sqlite_test.go index 7063efbc..23ef8773 100644 --- a/internal/infra/storage/sqlite_test.go +++ b/internal/infra/storage/sqlite_test.go @@ -165,7 +165,6 @@ func TestSQLiteStorage_ConversationManagement(t *testing.T) { metadata.Title = "New Title" metadata.Tags = []string{"updated", "test"} - metadata.Summary = "Updated summary" metadata.UpdatedAt = time.Now() err = storage.UpdateConversationMetadata(ctx, conversationID, metadata) @@ -176,7 +175,6 @@ func TestSQLiteStorage_ConversationManagement(t *testing.T) { assert.Equal(t, "New Title", loadedMetadata.Title) assert.Equal(t, []string{"updated", "test"}, loadedMetadata.Tags) - assert.Equal(t, "Updated summary", loadedMetadata.Summary) }) } @@ -250,8 +248,7 @@ func createTestMetadata(id string) ConversationMetadata { TotalTokens: 250, RequestCount: 2, }, - Model: "claude-4", - Tags: []string{"test", "demo"}, - Summary: "A test conversation", + Model: "claude-4", + Tags: []string{"test", "demo"}, } } diff --git a/internal/services/agent.go b/internal/services/agent.go index 8d3104cb..866ca938 100644 --- a/internal/services/agent.go +++ b/internal/services/agent.go @@ -248,7 +248,7 @@ func (s *AgentServiceImpl) Run(ctx context.Context, req *domain.AgentRequest) (* optimizedMessages := req.Messages if s.optimizer != nil { - optimizedMessages = s.optimizer.OptimizeMessagesWithModel(req.Messages, req.Model, false) + optimizedMessages = s.optimizer.OptimizeMessages(req.Messages, req.Model, false) } messages := s.addSystemPrompt(optimizedMessages) @@ -759,19 +759,9 @@ func (s *AgentServiceImpl) optimizeConversation(ctx context.Context, req *domain originalCount := len(conversation) - persistentRepo, isPersistent := s.conversationRepo.(*PersistentConversationRepository) - if isPersistent { - if cachedMessages := persistentRepo.GetOptimizedMessages(); len(cachedMessages) > 0 { - if len(conversation) <= len(cachedMessages) { - return cachedMessages - } - conversation = append(cachedMessages, conversation[len(cachedMessages):]...) - } - } - eventPublisher.publishOptimizationStatus("Optimizing conversation history...", true, originalCount, originalCount) - conversation = s.optimizer.OptimizeMessagesWithModel(conversation, req.Model, false) + conversation = s.optimizer.OptimizeMessages(conversation, req.Model, false) optimizedCount := len(conversation) var message string @@ -783,12 +773,6 @@ func (s *AgentServiceImpl) optimizeConversation(ctx context.Context, req *domain eventPublisher.publishOptimizationStatus(message, false, originalCount, optimizedCount) - if isPersistent { - if err := persistentRepo.SetOptimizedMessages(ctx, conversation); err != nil { - logger.Error("Failed to save optimized conversation", "error", err) - } - } - return conversation } diff --git a/internal/services/conversation.go b/internal/services/conversation.go index 0ab3bafd..841beeb0 100644 --- a/internal/services/conversation.go +++ b/internal/services/conversation.go @@ -499,3 +499,8 @@ func (r *InMemoryConversationRepository) FormatToolResultExpanded(result *domain } return "Tool execution failed" } + +// GetCurrentConversationTitle returns the current conversation title +func (r *InMemoryConversationRepository) GetCurrentConversationTitle() string { + return "New Conversation" +} diff --git a/internal/services/conversation_optimizer.go b/internal/services/conversation_optimizer.go index a04454cd..9ffb4dc5 100644 --- a/internal/services/conversation_optimizer.go +++ b/internal/services/conversation_optimizer.go @@ -64,13 +64,8 @@ func NewConversationOptimizer(config OptimizerConfig) *ConversationOptimizer { } } -// OptimizeMessages reduces token usage by intelligently managing conversation history -func (co *ConversationOptimizer) OptimizeMessages(messages []sdk.Message, force bool) []sdk.Message { - return co.OptimizeMessagesWithModel(messages, "", force) -} - -// OptimizeMessagesWithModel reduces token usage with optional current model for fallback -func (co *ConversationOptimizer) OptimizeMessagesWithModel(messages []sdk.Message, currentModel string, force bool) []sdk.Message { +// OptimizeMessages reduces token usage by intelligently managing conversation history with LLM summarization +func (co *ConversationOptimizer) OptimizeMessages(messages []sdk.Message, model string, force bool) []sdk.Message { if len(messages) == 0 { return messages } @@ -81,7 +76,7 @@ func (co *ConversationOptimizer) OptimizeMessagesWithModel(messages []sdk.Messag currentTokens := co.tokenizer.EstimateMessagesTokens(messages) - contextWindow := models.EstimateContextWindow(currentModel) + contextWindow := models.EstimateContextWindow(model) if contextWindow == 0 { contextWindow = 30000 } @@ -103,7 +98,7 @@ func (co *ConversationOptimizer) OptimizeMessagesWithModel(messages []sdk.Messag } } - optimized, err := co.smartOptimize(conversationMessages, currentModel) + optimized, err := co.smartOptimize(conversationMessages, model) if err != nil { logger.Error("Optimization failed", "error", err) // If optimization fails, return original messages @@ -145,7 +140,7 @@ func (co *ConversationOptimizer) smartOptimize(messages []sdk.Message, model str } logger.Info("Generating LLM summary for compaction", "model", model, "messages_to_summarize", len(messagesToSummarize)) - summary, err := co.generateLLMSummary(messagesToSummarize, model) + summary, err := co.GenerateLLMSummary(messagesToSummarize, model) if err != nil { logger.Error("Failed to generate LLM summary", "error", err) return nil, fmt.Errorf("failed to generate summary: %w", err) @@ -205,8 +200,10 @@ func (co *ConversationOptimizer) adjustBoundaryForToolCallsAtStart(messages []sd return adjustedBoundary } -// generateLLMSummary uses the SDK client to generate an intelligent summary -func (co *ConversationOptimizer) generateLLMSummary(messages []sdk.Message, model string) (string, error) { +// GenerateLLMSummary creates a concise summary of conversation messages using an LLM. +// It uses the SDK client to generate an intelligent summary focused on key tasks, +// decisions, critical context, and next steps. The summary is limited to 2-3 sentences. +func (co *ConversationOptimizer) GenerateLLMSummary(messages []sdk.Message, model string) (string, error) { ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() diff --git a/internal/services/conversation_optimizer_test.go b/internal/services/conversation_optimizer_test.go index 309d92f3..18b70b3f 100644 --- a/internal/services/conversation_optimizer_test.go +++ b/internal/services/conversation_optimizer_test.go @@ -31,7 +31,7 @@ func TestConversationOptimizer_ToolCallIntegrity_BufferBoundary(t *testing.T) { } optimizer := createTestOptimizer(2) - result := optimizer.OptimizeMessages(messages, false) + result := optimizer.OptimizeMessages(messages, "deepseek/deepseek-chat", false) validateAssistantToolCallsPreserved(t, result) validateNoOrphanedToolCalls(t, result) @@ -55,7 +55,7 @@ func TestConversationOptimizer_ToolCallIntegrity_ToolResponseInBuffer(t *testing } optimizer := createTestOptimizer(2) - result := optimizer.OptimizeMessages(messages, false) + result := optimizer.OptimizeMessages(messages, "deepseek/deepseek-chat", false) validateToolResponseHasAssistant(t, result) validateNoOrphanedToolCalls(t, result) @@ -88,7 +88,7 @@ func TestConversationOptimizer_ToolCallIntegrity_MultipleGroups(t *testing.T) { } optimizer := createTestOptimizer(3) - result := optimizer.OptimizeMessages(messages, false) + result := optimizer.OptimizeMessages(messages, "deepseek/deepseek-chat", false) validateNoOrphanedToolCalls(t, result) } @@ -110,7 +110,7 @@ func TestConversationOptimizer_ToolCallIntegrity_ExactBufferStart(t *testing.T) } optimizer := createTestOptimizer(2) - result := optimizer.OptimizeMessages(messages, false) + result := optimizer.OptimizeMessages(messages, "deepseek/deepseek-chat", false) validateNoOrphanedToolCalls(t, result) } @@ -129,7 +129,7 @@ func TestConversationOptimizer_ToolCallIntegrity_NoToolCalls(t *testing.T) { } optimizer := createTestOptimizer(2) - result := optimizer.OptimizeMessages(messages, false) + result := optimizer.OptimizeMessages(messages, "deepseek/deepseek-chat", false) validateNoOrphanedToolCalls(t, result) assert.NotEmpty(t, result) @@ -155,7 +155,7 @@ func TestConversationOptimizer_ToolCallIntegrity_PartialResponses(t *testing.T) } optimizer := createTestOptimizer(1) - result := optimizer.OptimizeMessages(messages, false) + result := optimizer.OptimizeMessages(messages, "deepseek/deepseek-chat", false) validateNoOrphanedToolCalls(t, result) } @@ -219,7 +219,7 @@ func TestConversationOptimizer_EdgeCases(t *testing.T) { Tokenizer: nil, }) - result := optimizer.OptimizeMessages(tt.messages, false) + result := optimizer.OptimizeMessages(tt.messages, "anthropic/claude-3-5-sonnet-20241022", false) if tt.expectEmpty { assert.Empty(t, result) @@ -251,7 +251,7 @@ func TestConversationOptimizer_DisabledOptimization(t *testing.T) { Tokenizer: nil, }) - result := optimizer.OptimizeMessages(messages, false) + result := optimizer.OptimizeMessages(messages, "deepseek/deepseek-chat", false) assert.Equal(t, len(messages), len(result)) assert.Equal(t, messages, result) diff --git a/internal/services/persistent_conversation.go b/internal/services/persistent_conversation.go index 7232cc5a..4c931021 100644 --- a/internal/services/persistent_conversation.go +++ b/internal/services/persistent_conversation.go @@ -340,55 +340,6 @@ func (r *PersistentConversationRepository) AddTokenUsage(model string, inputToke return nil } -// GetOptimizedMessages retrieves the stored optimized conversation messages -func (r *PersistentConversationRepository) GetOptimizedMessages() []sdk.Message { - if len(r.metadata.OptimizedMessages) == 0 { - return nil - } - - optimizedMessages := make([]sdk.Message, 0, len(r.metadata.OptimizedMessages)) - for _, entry := range r.metadata.OptimizedMessages { - optimizedMessages = append(optimizedMessages, sdk.Message{ - Role: entry.Message.Role, - Content: entry.Message.Content, - ToolCalls: entry.Message.ToolCalls, - ToolCallId: entry.Message.ToolCallId, - }) - } - return optimizedMessages -} - -// SetOptimizedMessages stores the optimized conversation messages -func (r *PersistentConversationRepository) SetOptimizedMessages(ctx context.Context, optimizedMessages []sdk.Message) error { - if r.conversationID == "" { - return fmt.Errorf("no active conversation to store optimized messages") - } - - r.autoSaveMutex.Lock() - defer r.autoSaveMutex.Unlock() - - conversationEntries := make([]domain.ConversationEntry, 0, len(optimizedMessages)) - now := time.Now() - - for _, msg := range optimizedMessages { - entry := domain.ConversationEntry{ - Message: domain.Message{ - Role: msg.Role, - Content: msg.Content, - ToolCalls: msg.ToolCalls, - ToolCallId: msg.ToolCallId, - }, - Time: now, - } - conversationEntries = append(conversationEntries, entry) - } - - r.metadata.OptimizedMessages = conversationEntries - r.metadata.UpdatedAt = now - - return r.storage.UpdateConversationMetadata(ctx, r.conversationID, r.metadata) -} - // Close closes the storage connection func (r *PersistentConversationRepository) Close() error { if r.storage != nil { @@ -396,3 +347,8 @@ func (r *PersistentConversationRepository) Close() error { } return nil } + +// GetCurrentConversationTitle returns the current conversation title +func (r *PersistentConversationRepository) GetCurrentConversationTitle() string { + return r.metadata.Title +} diff --git a/internal/shortcuts/core.go b/internal/shortcuts/core.go index 3da66739..81e4fb87 100644 --- a/internal/shortcuts/core.go +++ b/internal/shortcuts/core.go @@ -59,7 +59,7 @@ func NewCompactShortcut(repo domain.ConversationRepository) *CompactShortcut { func (c *CompactShortcut) GetName() string { return "compact" } func (c *CompactShortcut) GetDescription() string { - return "Optimize conversation to reduce token usage" + return "Save current conversation and start new session with summary" } func (c *CompactShortcut) GetUsage() string { return "/compact" } func (c *CompactShortcut) CanExecute(args []string) bool { return len(args) == 0 } diff --git a/tests/mocks/domain/fake_conversation_repository.go b/tests/mocks/domain/fake_conversation_repository.go index d2f89814..a8105919 100644 --- a/tests/mocks/domain/fake_conversation_repository.go +++ b/tests/mocks/domain/fake_conversation_repository.go @@ -113,6 +113,16 @@ type FakeConversationRepository struct { formatToolResultForUIReturnsOnCall map[int]struct { result1 string } + GetCurrentConversationTitleStub func() string + getCurrentConversationTitleMutex sync.RWMutex + getCurrentConversationTitleArgsForCall []struct { + } + getCurrentConversationTitleReturns struct { + result1 string + } + getCurrentConversationTitleReturnsOnCall map[int]struct { + result1 string + } GetMessageCountStub func() int getMessageCountMutex sync.RWMutex getMessageCountArgsForCall []struct { @@ -736,6 +746,59 @@ func (fake *FakeConversationRepository) FormatToolResultForUIReturnsOnCall(i int }{result1} } +func (fake *FakeConversationRepository) GetCurrentConversationTitle() string { + fake.getCurrentConversationTitleMutex.Lock() + ret, specificReturn := fake.getCurrentConversationTitleReturnsOnCall[len(fake.getCurrentConversationTitleArgsForCall)] + fake.getCurrentConversationTitleArgsForCall = append(fake.getCurrentConversationTitleArgsForCall, struct { + }{}) + stub := fake.GetCurrentConversationTitleStub + fakeReturns := fake.getCurrentConversationTitleReturns + fake.recordInvocation("GetCurrentConversationTitle", []interface{}{}) + fake.getCurrentConversationTitleMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeConversationRepository) GetCurrentConversationTitleCallCount() int { + fake.getCurrentConversationTitleMutex.RLock() + defer fake.getCurrentConversationTitleMutex.RUnlock() + return len(fake.getCurrentConversationTitleArgsForCall) +} + +func (fake *FakeConversationRepository) GetCurrentConversationTitleCalls(stub func() string) { + fake.getCurrentConversationTitleMutex.Lock() + defer fake.getCurrentConversationTitleMutex.Unlock() + fake.GetCurrentConversationTitleStub = stub +} + +func (fake *FakeConversationRepository) GetCurrentConversationTitleReturns(result1 string) { + fake.getCurrentConversationTitleMutex.Lock() + defer fake.getCurrentConversationTitleMutex.Unlock() + fake.GetCurrentConversationTitleStub = nil + fake.getCurrentConversationTitleReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeConversationRepository) GetCurrentConversationTitleReturnsOnCall(i int, result1 string) { + fake.getCurrentConversationTitleMutex.Lock() + defer fake.getCurrentConversationTitleMutex.Unlock() + fake.GetCurrentConversationTitleStub = nil + if fake.getCurrentConversationTitleReturnsOnCall == nil { + fake.getCurrentConversationTitleReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.getCurrentConversationTitleReturnsOnCall[i] = struct { + result1 string + }{result1} +} + func (fake *FakeConversationRepository) GetMessageCount() int { fake.getMessageCountMutex.Lock() ret, specificReturn := fake.getMessageCountReturnsOnCall[len(fake.getMessageCountArgsForCall)]