Skip to content

Commit fd70ebb

Browse files
Fix test race condition
1 parent 19cb17a commit fd70ebb

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

internal/handlers/main_test.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http/httptest"
1010
"slices"
1111
"strings"
12+
"sync"
1213
"testing"
1314

1415
"github.com/MegaGrindStone/go-mcp"
@@ -22,6 +23,7 @@ type mockLLM struct {
2223
}
2324

2425
type mockStore struct {
26+
sync.Mutex
2527
chats []models.Chat
2628
messages map[string][]models.Message
2729
err error
@@ -349,13 +351,20 @@ func (m mockLLM) GenerateTitle(_ context.Context, _ string) (string, error) {
349351
}
350352

351353
func (m *mockStore) Chats(_ context.Context) ([]models.Chat, error) {
354+
m.Lock()
355+
defer m.Unlock()
352356
if m.err != nil {
353357
return nil, m.err
354358
}
355-
return m.chats, nil
359+
// Return a copy to avoid race conditions on the slice
360+
chatsCopy := make([]models.Chat, len(m.chats))
361+
copy(chatsCopy, m.chats)
362+
return chatsCopy, nil
356363
}
357364

358365
func (m *mockStore) AddChat(_ context.Context, chat models.Chat) (string, error) {
366+
m.Lock()
367+
defer m.Unlock()
359368
if m.err != nil {
360369
return "", m.err
361370
}
@@ -364,6 +373,8 @@ func (m *mockStore) AddChat(_ context.Context, chat models.Chat) (string, error)
364373
}
365374

366375
func (m *mockStore) UpdateChat(_ context.Context, chat models.Chat) error {
376+
m.Lock()
377+
defer m.Unlock()
367378
idx := slices.IndexFunc(m.chats, func(c models.Chat) bool { return c.ID == chat.ID })
368379
if idx == -1 {
369380
return fmt.Errorf("chat not found")
@@ -373,13 +384,20 @@ func (m *mockStore) UpdateChat(_ context.Context, chat models.Chat) error {
373384
}
374385

375386
func (m *mockStore) Messages(_ context.Context, chatID string) ([]models.Message, error) {
387+
m.Lock()
388+
defer m.Unlock()
376389
if m.err != nil {
377390
return nil, m.err
378391
}
379-
return m.messages[chatID], nil
392+
// Return a copy to avoid race conditions on the slice
393+
messagesCopy := make([]models.Message, len(m.messages[chatID]))
394+
copy(messagesCopy, m.messages[chatID])
395+
return messagesCopy, nil
380396
}
381397

382398
func (m *mockStore) AddMessage(_ context.Context, chatID string, msg models.Message) (string, error) {
399+
m.Lock()
400+
defer m.Unlock()
383401
if m.err != nil {
384402
return "", m.err
385403
}
@@ -388,5 +406,7 @@ func (m *mockStore) AddMessage(_ context.Context, chatID string, msg models.Mess
388406
}
389407

390408
func (m *mockStore) UpdateMessage(_ context.Context, _ string, _ models.Message) error {
409+
m.Lock()
410+
defer m.Unlock()
391411
return m.err
392412
}

0 commit comments

Comments
 (0)