@@ -11,13 +11,15 @@ import (
1111 "os"
1212 "path/filepath"
1313 "reflect"
14+ "slices"
15+ "strings"
1416 "testing"
1517
1618 "github.com/modelcontextprotocol/go-sdk/mcp"
1719)
1820
19- // getStoreFactories provides test factories for both storage implementations.
20- func getStoreFactories () map [string ]func (t * testing.T ) store {
21+ // stores provides test factories for both storage implementations.
22+ func stores () map [string ]func (t * testing.T ) store {
2123 return map [string ]func (t * testing.T ) store {
2224 "file" : func (t * testing.T ) store {
2325 tempDir , err := os .MkdirTemp ("" , "kb-test-file-*" )
@@ -35,11 +37,9 @@ func getStoreFactories() map[string]func(t *testing.T) store {
3537
3638// TestKnowledgeBaseOperations verifies CRUD operations work correctly.
3739func TestKnowledgeBaseOperations (t * testing.T ) {
38- factories := getStoreFactories ()
39-
40- for name , factory := range factories {
40+ for name , newStore := range stores () {
4141 t .Run (name , func (t * testing.T ) {
42- s := factory (t )
42+ s := newStore (t )
4343 kb := knowledgeBase {s : s }
4444
4545 // Verify empty graph loads correctly
@@ -147,19 +147,16 @@ func TestKnowledgeBaseOperations(t *testing.T) {
147147
148148 // Confirm observation removal
149149 graph , _ = kb .loadGraph ()
150- aliceFound := false
151- for _ , e := range graph .Entities {
152- if e .Name == "Alice" {
153- aliceFound = true
154- for _ , obs := range e .Observations {
155- if obs == "Works as developer" {
156- t .Errorf ("observation 'Works as developer' should have been deleted" )
157- }
158- }
159- }
160- }
161- if ! aliceFound {
150+ aliceIndex := slices .IndexFunc (graph .Entities , func (e Entity ) bool {
151+ return e .Name == "Alice"
152+ })
153+ if aliceIndex == - 1 {
162154 t .Errorf ("entity 'Alice' not found after deleting observation" )
155+ } else {
156+ alice := graph .Entities [aliceIndex ]
157+ if slices .Contains (alice .Observations , "Works as developer" ) {
158+ t .Errorf ("observation 'Works as developer' should have been deleted" )
159+ }
163160 }
164161
165162 // Remove relations
@@ -191,11 +188,9 @@ func TestKnowledgeBaseOperations(t *testing.T) {
191188
192189// TestSaveAndLoadGraph ensures data persists correctly across save/load cycles.
193190func TestSaveAndLoadGraph (t * testing.T ) {
194- factories := getStoreFactories ()
195-
196- for name , factory := range factories {
191+ for name , newStore := range stores () {
197192 t .Run (name , func (t * testing.T ) {
198- s := factory (t )
193+ s := newStore (t )
199194 kb := knowledgeBase {s : s }
200195
201196 // Setup test data
@@ -251,11 +246,9 @@ func TestSaveAndLoadGraph(t *testing.T) {
251246
252247// TestDuplicateEntitiesAndRelations verifies duplicate prevention logic.
253248func TestDuplicateEntitiesAndRelations (t * testing.T ) {
254- factories := getStoreFactories ()
255-
256- for name , factory := range factories {
249+ for name , newStore := range stores () {
257250 t .Run (name , func (t * testing.T ) {
258- s := factory (t )
251+ s := newStore (t )
259252 kb := knowledgeBase {s : s }
260253
261254 // Setup initial state
@@ -355,10 +348,9 @@ func TestErrorHandling(t *testing.T) {
355348 }
356349 })
357350
358- factories := getStoreFactories ()
359- for name , factory := range factories {
351+ for name , newStore := range stores () {
360352 t .Run (fmt .Sprintf ("AddObservationToNonExistentEntity_%s" , name ), func (t * testing.T ) {
361- s := factory (t )
353+ s := newStore (t )
362354 kb := knowledgeBase {s : s }
363355
364356 // Setup valid entity for comparison
@@ -385,11 +377,9 @@ func TestErrorHandling(t *testing.T) {
385377
386378// TestFileFormatting verifies the JSON storage format structure.
387379func TestFileFormatting (t * testing.T ) {
388- factories := getStoreFactories ()
389-
390- for name , factory := range factories {
380+ for name , newStore := range stores () {
391381 t .Run (name , func (t * testing.T ) {
392- s := factory (t )
382+ s := newStore (t )
393383 kb := knowledgeBase {s : s }
394384
395385 // Setup test entity
@@ -438,11 +428,9 @@ func TestFileFormatting(t *testing.T) {
438428
439429// TestMCPServerIntegration tests the knowledge base through MCP server layer.
440430func TestMCPServerIntegration (t * testing.T ) {
441- factories := getStoreFactories ()
442-
443- for name , factory := range factories {
431+ for name , newStore := range stores () {
444432 t .Run (name , func (t * testing.T ) {
445- s := factory (t )
433+ s := newStore (t )
446434 kb := knowledgeBase {s : s }
447435
448436 // Create mock server session
@@ -639,11 +627,9 @@ func TestMCPServerIntegration(t *testing.T) {
639627
640628// TestMCPErrorHandling tests error scenarios through MCP layer.
641629func TestMCPErrorHandling (t * testing.T ) {
642- factories := getStoreFactories ()
643-
644- for name , factory := range factories {
630+ for name , newStore := range stores () {
645631 t .Run (name , func (t * testing.T ) {
646- s := factory (t )
632+ s := newStore (t )
647633 kb := knowledgeBase {s : s }
648634
649635 ctx := context .Background ()
@@ -661,15 +647,15 @@ func TestMCPErrorHandling(t *testing.T) {
661647 },
662648 }
663649
664- obsResult , err := kb .AddObservations (ctx , serverSession , addObsParams )
665- if err != nil {
666- t .Fatalf ("MCP AddObservations call failed: %v" , err )
667- }
668- if ! obsResult .IsError {
650+ _ , err := kb .AddObservations (ctx , serverSession , addObsParams )
651+ if err == nil {
669652 t .Errorf ("expected MCP AddObservations to return error for non-existent entity" )
670- }
671- if len (obsResult .Content ) == 0 {
672- t .Errorf ("expected error content in MCP response" )
653+ } else {
654+ // Verify the error message contains expected text
655+ expectedErrorMsg := "entity with name NonExistentEntity not found"
656+ if ! strings .Contains (err .Error (), expectedErrorMsg ) {
657+ t .Errorf ("expected error message to contain '%s', got: %v" , expectedErrorMsg , err )
658+ }
673659 }
674660 })
675661 }
0 commit comments