diff --git a/go.mod b/go.mod index b4e9cce3..3483540d 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/mark3labs/mcp-go v0.31.0 github.com/spf13/cobra v1.9.1 github.com/spf13/pflag v1.0.6 + github.com/stretchr/testify v1.10.0 go.uber.org/mock v0.6.0 golang.org/x/sync v0.16.0 golang.org/x/term v0.31.0 @@ -61,6 +62,7 @@ require ( github.com/charmbracelet/x/cellbuf v0.0.13 // indirect github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect github.com/charmbracelet/x/term v0.2.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dlclark/regexp2 v1.11.4 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -87,6 +89,7 @@ require ( github.com/ollama/ollama v0.6.5 // indirect github.com/openai/openai-go v1.11.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect github.com/spf13/cast v1.7.1 // indirect @@ -111,4 +114,5 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 // indirect google.golang.org/grpc v1.70.0 // indirect google.golang.org/protobuf v1.36.5 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/gollm/azopenai.go b/gollm/azopenai.go index 512fc7cd..fb135725 100644 --- a/gollm/azopenai.go +++ b/gollm/azopenai.go @@ -271,10 +271,55 @@ func (c *AzureOpenAIChat) IsRetryableError(err error) bool { } func (c *AzureOpenAIChat) Initialize(messages []*api.Message) error { - klog.Warning("chat history persistence is not supported for provider 'azopenai', using in-memory chat history") + klog.Info("Initializing Azure OpenAI chat with history") + c.history = make([]azopenai.ChatRequestMessageClassification, 0, len(messages)) + for _, msg := range messages { + content, err := c.messageToAzureContent(msg) + if err != nil { + continue // Skip malformed messages but continue processing + } + c.history = append(c.history, content) + } return nil } +func (c *AzureOpenAIChat) messageToAzureContent(msg *api.Message) (azopenai.ChatRequestMessageClassification, error) { + var role string + switch msg.Source { + case api.MessageSourceUser: + role = "user" + case api.MessageSourceModel: + role = "assistant" + case api.MessageSourceAgent: + role = "user" // Treat agent messages as user messages + default: + return nil, fmt.Errorf("unknown message source: %s", msg.Source) + } + + switch v := msg.Payload.(type) { + case string: + if role == "user" { + return &azopenai.ChatRequestUserMessage{ + Content: azopenai.NewChatRequestUserMessageContent(v), + }, nil + } else { + return &azopenai.ChatRequestAssistantMessage{ + Content: azopenai.NewChatRequestAssistantMessageContent(v), + }, nil + } + case FunctionCallResult: + // Handle function call results appropriately + return &azopenai.ChatRequestUserMessage{ + Content: azopenai.NewChatRequestUserMessageContent(fmt.Sprintf("Function call result: %s", v.Result)), + }, nil + default: + // Convert unknown types to string representation + return &azopenai.ChatRequestUserMessage{ + Content: azopenai.NewChatRequestUserMessageContent(fmt.Sprintf("%v", v)), + }, nil + } +} + func (c *AzureOpenAIChat) SendStreaming(ctx context.Context, contents ...any) (ChatResponseIterator, error) { // TODO: Implement streaming response, err := c.Send(ctx, contents...) diff --git a/gollm/grok.go b/gollm/grok.go index 8271c317..c9461778 100644 --- a/gollm/grok.go +++ b/gollm/grok.go @@ -381,10 +381,67 @@ func (cs *grokChatSession) IsRetryableError(err error) bool { } func (cs *grokChatSession) Initialize(messages []*api.Message) error { - klog.Warning("chat history persistence is not supported for provider 'grok', using in-memory chat history") + klog.Info("Initializing Grok chat with history") + cs.history = make([]openai.ChatCompletionMessageParamUnion, 0, len(messages)) + for _, msg := range messages { + content, err := cs.messageToGrokProvider(msg) + if err != nil { + continue // Skip malformed messages but continue processing + } + cs.history = append(cs.history, content) + } return nil } +// messageToGrokProvider converts api.Message to Grok-specific message format +func (cs *grokChatSession) messageToGrokProvider(msg *api.Message) (openai.ChatCompletionMessageParamUnion, error) { + var role string + switch msg.Source { + case api.MessageSourceUser: + role = "user" + case api.MessageSourceModel: + role = "assistant" + case api.MessageSourceAgent: + role = "system" // Grok treats agent messages as system messages + default: + return openai.UserMessage(""), fmt.Errorf("unknown message source: %s", msg.Source) + } + + switch v := msg.Payload.(type) { + case string: + switch role { + case "user": + return openai.UserMessage(v), nil + case "assistant": + return openai.AssistantMessage(v), nil + case "system": + return openai.SystemMessage(v), nil + default: + return openai.UserMessage(v), nil + } + case FunctionCallResult: + // Handle function call results as tool messages for Grok + resultJSON, err := json.Marshal(v.Result) + if err != nil { + return openai.UserMessage(""), fmt.Errorf("failed to marshal function call result: %w", err) + } + return openai.ToolMessage(string(resultJSON), v.ID), nil + default: + // Convert unknown types to string representation + content := fmt.Sprintf("%v", v) + switch role { + case "user": + return openai.UserMessage(content), nil + case "assistant": + return openai.AssistantMessage(content), nil + case "system": + return openai.SystemMessage(content), nil + default: + return openai.UserMessage(content), nil + } + } +} + // --- Helper structs for ChatResponse interface --- type grokChatResponse struct { diff --git a/gollm/llamacpp.go b/gollm/llamacpp.go index 33682024..6f0c08c4 100644 --- a/gollm/llamacpp.go +++ b/gollm/llamacpp.go @@ -296,10 +296,50 @@ func (c *LlamaCppChat) IsRetryableError(err error) bool { } func (c *LlamaCppChat) Initialize(messages []*api.Message) error { - klog.Warning("chat history persistence is not supported for provider 'llamacpp', using in-memory chat history") + klog.Info("Initializing llama.cpp chat with history") + c.history = make([]llamacppChatMessage, 0, len(messages)) + for _, msg := range messages { + content, err := c.messageToLlamaCppContent(msg) + if err != nil { + // Skip malformed messages but continue processing + continue + } + c.history = append(c.history, content) + } return nil } +func (c *LlamaCppChat) messageToLlamaCppContent(msg *api.Message) (llamacppChatMessage, error) { + var role string + switch msg.Source { + case api.MessageSourceUser: + role = "user" + case api.MessageSourceModel: + role = "assistant" + case api.MessageSourceAgent: + // Treat agent messages as system messages to seed context + role = "system" + default: + return llamacppChatMessage{}, fmt.Errorf("unknown message source: %s", msg.Source) + } + + switch v := msg.Payload.(type) { + case string: + return llamacppChatMessage{Role: role, Content: ptrTo(v)}, nil + case FunctionCallResult: + // Represent function call results as tool message with JSON content + resultJSON, err := json.Marshal(v.Result) + if err != nil { + return llamacppChatMessage{}, fmt.Errorf("failed to marshal function call result: %w", err) + } + return llamacppChatMessage{Role: "tool", Content: ptrTo(string(resultJSON)), ToolCallID: v.ID}, nil + default: + // Convert unknown types to string representation + content := fmt.Sprintf("%v", v) + return llamacppChatMessage{Role: role, Content: ptrTo(content)}, nil + } +} + func ptrTo[T any](t T) *T { return &t } diff --git a/gollm/ollama.go b/gollm/ollama.go index 720288d5..19a4c041 100644 --- a/gollm/ollama.go +++ b/gollm/ollama.go @@ -210,10 +210,51 @@ func (c *OllamaChat) SendStreaming(ctx context.Context, contents ...any) (ChatRe } func (c *OllamaChat) Initialize(messages []*kctlApi.Message) error { - klog.Warning("chat history persistence is not supported for provider 'ollama', using in-memory chat history") + klog.Info("Initializing ollama chat with history") + c.history = make([]api.Message, 0, len(messages)) + for _, msg := range messages { + content, err := c.messageToOllamaContent(msg) + if err != nil { + // Skip malformed messages but continue processing + continue + } + c.history = append(c.history, content) + } return nil } +func (c *OllamaChat) messageToOllamaContent(msg *kctlApi.Message) (api.Message, error) { + var role string + switch msg.Source { + case kctlApi.MessageSourceUser: + role = "user" + case kctlApi.MessageSourceModel: + role = "assistant" + case kctlApi.MessageSourceAgent: + // Treat agent messages as system to seed context + role = "system" + default: + return api.Message{}, fmt.Errorf("unknown message source: %s", msg.Source) + } + + switch v := msg.Payload.(type) { + case string: + return api.Message{Role: role, Content: v}, nil + case FunctionCallResult: + // Represent tool output as a tool response; Ollama does not have a distinct tool role in history API, + // so include a textual representation for context. + resultJSON, err := json.Marshal(v.Result) + if err != nil { + return api.Message{}, fmt.Errorf("failed to marshal function call result: %w", err) + } + return api.Message{Role: "user", Content: string(resultJSON)}, nil + default: + // Convert unknown types to string representation + content := fmt.Sprintf("%v", v) + return api.Message{Role: role, Content: content}, nil + } +} + type OllamaChatResponse struct { candidates []*OllamaCandidate ollamaResponse api.ChatResponse diff --git a/gollm/openai.go b/gollm/openai.go index e26dd86c..3824be43 100644 --- a/gollm/openai.go +++ b/gollm/openai.go @@ -424,10 +424,56 @@ func (cs *openAIChatSession) IsRetryableError(err error) bool { } func (cs *openAIChatSession) Initialize(messages []*api.Message) error { - klog.Warning("chat history persistence is not supported for provider 'openai', using in-memory chat history") + klog.Info("Initializing OpenAI chat with history") + cs.history = make([]openai.ChatCompletionMessageParamUnion, 0, len(messages)) + for _, msg := range messages { + content, err := cs.messageToOpenAIContent(msg) + if err != nil { + continue // Skip malformed messages but continue processing + } + cs.history = append(cs.history, content) + } return nil } +func (cs *openAIChatSession) messageToOpenAIContent(msg *api.Message) (openai.ChatCompletionMessageParamUnion, error) { + var role string + switch msg.Source { + case api.MessageSourceUser: + role = "user" + case api.MessageSourceModel: + role = "assistant" + case api.MessageSourceAgent: + role = "agent" + default: + return openai.UserMessage(""), fmt.Errorf("unknown message source: %s", msg.Source) + } + + switch v := msg.Payload.(type) { + case string: + if role == "user" { + return openai.UserMessage(v), nil + } else { + return openai.AssistantMessage(v), nil + } + case FunctionCallResult: + // Handle function call results as tool messages + resultJSON, err := json.Marshal(v.Result) + if err != nil { + return openai.UserMessage(""), fmt.Errorf("failed to marshal function call result: %w", err) + } + return openai.ToolMessage(string(resultJSON), v.ID), nil + default: + // Convert unknown types to string representation + content := fmt.Sprintf("%v", v) + if role == "user" { + return openai.UserMessage(content), nil + } else { + return openai.AssistantMessage(content), nil + } + } +} + // Helper structs for ChatResponse interface type openAIChatResponse struct { diff --git a/pkg/sessions/session.go b/pkg/sessions/session.go index cbf35640..9bdb1ff7 100644 --- a/pkg/sessions/session.go +++ b/pkg/sessions/session.go @@ -15,6 +15,7 @@ package sessions import ( + "bufio" "encoding/json" "fmt" "os" @@ -146,11 +147,18 @@ func (s *Session) ChatMessages() []*api.Message { } defer f.Close() - scanner := json.NewDecoder(f) - for scanner.More() { + // Read file line by line instead of using json.Decoder + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + var message api.Message - if err := scanner.Decode(&message); err != nil { - continue // skip malformed messages + if err := json.Unmarshal([]byte(line), &message); err != nil { + // Skip malformed messages + continue } messages = append(messages, &message) } diff --git a/pkg/sessions/session_test.go b/pkg/sessions/session_test.go new file mode 100644 index 00000000..c8a10fcf --- /dev/null +++ b/pkg/sessions/session_test.go @@ -0,0 +1,312 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sessions + +import ( + "os" + "testing" + "time" + + "github.com/GoogleCloudPlatform/kubectl-ai/pkg/api" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupTestManager creates a temporary directory and SessionManager for testing +func setupTestManager(t *testing.T) (*SessionManager, func()) { + t.Helper() + tempDir, err := os.MkdirTemp("", "session-test-*") + require.NoError(t, err) + + cleanup := func() { + os.RemoveAll(tempDir) + } + + manager := &SessionManager{BasePath: tempDir} + return manager, cleanup +} + +// createTestMetadata returns standard test metadata +func createTestMetadata() Metadata { + return Metadata{ + ProviderID: "test-provider", + ModelID: "test-model", + } +} + +// createTestMessage creates a test message with the given payload +func createTestMessage(payload string) *api.Message { + return &api.Message{ + ID: uuid.New().String(), + Source: api.MessageSourceUser, + Type: api.MessageTypeText, + Payload: payload, + Timestamp: time.Now(), + } +} + +// TestSessionPersistence tests the basic save and load functionality +func TestSessionPersistence(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + // Create a new session + session, err := manager.NewSession(createTestMetadata()) + require.NoError(t, err) + + // Add some test messages + testMessage := createTestMessage("Hello, how can I help?") + err = session.AddChatMessage(testMessage) + require.NoError(t, err) + + // Load the session and verify its contents + loadedSession, err := manager.FindSessionByID(session.ID) + require.NoError(t, err) + + messages := loadedSession.ChatMessages() + require.Equal(t, 1, len(messages)) + assert.Equal(t, testMessage.Payload, messages[0].Payload) + + // Verify metadata + loadedMeta, err := loadedSession.LoadMetadata() + require.NoError(t, err) + expectedMeta := createTestMetadata() + assert.Equal(t, expectedMeta.ProviderID, loadedMeta.ProviderID) + assert.Equal(t, expectedMeta.ModelID, loadedMeta.ModelID) +} + +// TestCreateNewSession tests the creation of a new session +func TestCreateNewSession(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + session, err := manager.NewSession(createTestMetadata()) + require.NoError(t, err) + assert.NotEmpty(t, session.ID) + assert.NotEmpty(t, session.Path) + + // Verify history file is created after adding a message: + testMessage := createTestMessage("Test message") + err = session.AddChatMessage(testMessage) + require.NoError(t, err) + assert.FileExists(t, session.HistoryPath()) + + // Verify session directory and files exist + assert.DirExists(t, session.Path) + assert.FileExists(t, session.MetadataPath()) + assert.FileExists(t, session.HistoryPath()) + + // Verify metadata + loadedMeta, err := session.LoadMetadata() + require.NoError(t, err) + expectedMeta := createTestMetadata() + assert.Equal(t, expectedMeta.ProviderID, loadedMeta.ProviderID) + assert.Equal(t, expectedMeta.ModelID, loadedMeta.ModelID) + assert.False(t, loadedMeta.CreatedAt.IsZero()) + assert.False(t, loadedMeta.LastAccessed.IsZero()) +} + +// TestDeleteSession tests session deletion +func TestDeleteSession(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + // Create a session + session, err := manager.NewSession(createTestMetadata()) + require.NoError(t, err) + + // Delete the session + err = manager.DeleteSession(session.ID) + require.NoError(t, err) + + // Verify session directory is gone + _, err = os.Stat(session.Path) + assert.True(t, os.IsNotExist(err)) + + // Verify session can't be found + _, err = manager.FindSessionByID(session.ID) + assert.Error(t, err) +} + +// TestListSessions tests listing all available sessions +func TestListSessions(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + // Create multiple sessions + for i := 0; i < 3; i++ { + _, err := manager.NewSession(createTestMetadata()) + require.NoError(t, err) + } + + // List sessions + sessions, err := manager.ListSessions() + require.NoError(t, err) + assert.Equal(t, 3, len(sessions)) + + // Verify sessions are sorted by ID (newest first) + for i := 1; i < len(sessions); i++ { + assert.True(t, sessions[i-1].ID > sessions[i].ID) + } +} + +// TestCorruptedMetadata tests handling of corrupted metadata +func TestCorruptedMetadata(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + // Create a session + session, err := manager.NewSession(createTestMetadata()) + require.NoError(t, err) + + // Corrupt the metadata file + err = os.WriteFile(session.MetadataPath(), []byte("corrupted yaml"), 0644) + require.NoError(t, err) + + // Attempt to load metadata + _, err = session.LoadMetadata() + assert.Error(t, err) +} + +// TestCorruptedHistory tests handling of corrupted history file +func TestCorruptedHistory(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + // Create a session + session, err := manager.NewSession(createTestMetadata()) + require.NoError(t, err) + + // Add a valid message + err = session.AddChatMessage(createTestMessage("Valid message")) + require.NoError(t, err) + + // Append corrupted JSON to history file + f, err := os.OpenFile(session.HistoryPath(), os.O_APPEND|os.O_WRONLY, 0644) + require.NoError(t, err) + _, err = f.WriteString("corrupted json\n") + require.NoError(t, err) + f.Close() + + // Verify we can still read valid messages + messages := session.ChatMessages() + assert.Equal(t, 1, len(messages)) + assert.Equal(t, "Valid message", messages[0].Payload) +} + +// TestConcurrentAccess tests concurrent access to a session +func TestConcurrentAccess(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + // Create a session + session, err := manager.NewSession(createTestMetadata()) + require.NoError(t, err) + + // Test concurrent reads and writes + done := make(chan bool) + messageCount := 100 + + for i := 0; i < messageCount; i++ { + go func(i int) { + msg := createTestMessage("Concurrent message") + err := session.AddChatMessage(msg) + assert.NoError(t, err) + done <- true + }(i) + } + + // Wait for all goroutines to finish + for i := 0; i < messageCount; i++ { + <-done + } + + // Verify all messages were written + messages := session.ChatMessages() + assert.Equal(t, messageCount, len(messages)) +} + +// TestClearMessages tests clearing all messages from a session +func TestClearMessages(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + // Create a session + session, err := manager.NewSession(createTestMetadata()) + require.NoError(t, err) + + // Add some messages + for i := 0; i < 3; i++ { + err = session.AddChatMessage(createTestMessage("Test message")) + require.NoError(t, err) + } + + // Verify messages were added + assert.Equal(t, 3, len(session.ChatMessages())) + + // Clear messages + err = session.ClearChatMessages() + require.NoError(t, err) + + // Verify messages were cleared + assert.Empty(t, session.ChatMessages()) +} + +// TestGetLatestSession tests getting the most recent session +func TestGetLatestSession(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + // Create multiple sessions + var lastSession *Session + var err error + for i := 0; i < 3; i++ { + lastSession, err = manager.NewSession(createTestMetadata()) + require.NoError(t, err) + time.Sleep(time.Millisecond) // Ensure different timestamps + } + + // Get latest session + latest, err := manager.GetLatestSession() + require.NoError(t, err) + assert.Equal(t, lastSession.ID, latest.ID) +} + +// TestUpdateLastAccessed tests updating the last accessed timestamp +func TestUpdateLastAccessed(t *testing.T) { + manager, cleanup := setupTestManager(t) + defer cleanup() + + // Create a session + session, err := manager.NewSession(createTestMetadata()) + require.NoError(t, err) + + // Get initial last accessed time + meta, err := session.LoadMetadata() + require.NoError(t, err) + initialAccess := meta.LastAccessed + + time.Sleep(time.Millisecond) // Ensure different timestamp + + // Update last accessed + err = session.UpdateLastAccessed() + require.NoError(t, err) + + // Verify last accessed was updated + meta, err = session.LoadMetadata() + require.NoError(t, err) + assert.True(t, meta.LastAccessed.After(initialAccess)) +}