From e652a7e26410100ea9bdbd9bcfbbbfdec1567c6a Mon Sep 17 00:00:00 2001 From: MegaGrindStone Date: Thu, 3 Jul 2025 03:54:28 +0700 Subject: [PATCH 1/7] mcp: memory server example --- examples/memory/kb.go | 689 +++++++++++++++++++++++++++++++++++ examples/memory/kb_test.go | 717 +++++++++++++++++++++++++++++++++++++ examples/memory/main.go | 162 +++++++++ 3 files changed, 1568 insertions(+) create mode 100644 examples/memory/kb.go create mode 100644 examples/memory/kb_test.go create mode 100644 examples/memory/main.go diff --git a/examples/memory/kb.go b/examples/memory/kb.go new file mode 100644 index 00000000..b6a8c4dd --- /dev/null +++ b/examples/memory/kb.go @@ -0,0 +1,689 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "slices" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// store provides persistence interface for knowledge base data. +type store interface { + Read() ([]byte, error) + Write(data []byte) error +} + +// memoryStore implements in-memory storage that doesn't persist across restarts. +type memoryStore struct { + data []byte +} + +// Read returns the in-memory data. +func (ms *memoryStore) Read() ([]byte, error) { + return ms.data, nil +} + +// Write stores data in memory. +func (ms *memoryStore) Write(data []byte) error { + ms.data = data + return nil +} + +// fileStore implements file-based storage for persistent knowledge base. +type fileStore struct { + path string +} + +// Read loads data from file, returning empty slice if file doesn't exist. +func (fs *fileStore) Read() ([]byte, error) { + data, err := os.ReadFile(fs.path) + if err != nil { + if os.IsNotExist(err) { + return []byte{}, nil + } + return nil, fmt.Errorf("failed to read file %s: %w", fs.path, err) + } + return data, nil +} + +// Write saves data to file with 0600 permissions. +func (fs *fileStore) Write(data []byte) error { + if err := os.WriteFile(fs.path, data, 0600); err != nil { + return fmt.Errorf("failed to write file %s: %w", fs.path, err) + } + return nil +} + +// knowledgeBase manages entities and relations with persistent storage. +type knowledgeBase struct { + s store +} + +// kbItem represents a single item in persistent storage (entity or relation). +type kbItem struct { + Type string `json:"type"` + + // Entity fields (when Type == "entity") + Name string `json:"name,omitempty"` + EntityType string `json:"entityType,omitempty"` + Observations []string `json:"observations,omitempty"` + + // Relation fields (when Type == "relation") + From string `json:"from,omitempty"` + To string `json:"to,omitempty"` + RelationType string `json:"relationType,omitempty"` +} + +// loadGraph deserializes the knowledge graph from storage. +func (k knowledgeBase) loadGraph() (KnowledgeGraph, error) { + data, err := k.s.Read() + if err != nil { + return KnowledgeGraph{}, fmt.Errorf("failed to read from store: %w", err) + } + + if len(data) == 0 { + return KnowledgeGraph{}, nil + } + + var items []kbItem + if err := json.Unmarshal(data, &items); err != nil { + return KnowledgeGraph{}, fmt.Errorf("failed to unmarshal from store: %w", err) + } + + graph := KnowledgeGraph{ + Entities: []Entity{}, + Relations: []Relation{}, + } + + for _, item := range items { + switch item.Type { + case "entity": + graph.Entities = append(graph.Entities, Entity{ + Name: item.Name, + EntityType: item.EntityType, + Observations: item.Observations, + }) + case "relation": + graph.Relations = append(graph.Relations, Relation{ + From: item.From, + To: item.To, + RelationType: item.RelationType, + }) + } + } + + return graph, nil +} + +// saveGraph serializes and persists the knowledge graph to storage. +func (k knowledgeBase) saveGraph(graph KnowledgeGraph) error { + items := make([]kbItem, 0, len(graph.Entities)+len(graph.Relations)) + + for _, entity := range graph.Entities { + items = append(items, kbItem{ + Type: "entity", + Name: entity.Name, + EntityType: entity.EntityType, + Observations: entity.Observations, + }) + } + + for _, relation := range graph.Relations { + items = append(items, kbItem{ + Type: "relation", + From: relation.From, + To: relation.To, + RelationType: relation.RelationType, + }) + } + + itemsJSON, err := json.Marshal(items) + if err != nil { + return fmt.Errorf("failed to marshal items: %w", err) + } + + if err := k.s.Write(itemsJSON); err != nil { + return fmt.Errorf("failed to write to store: %w", err) + } + return nil +} + +// createEntities adds new entities to the graph, skipping duplicates by name. +func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var newEntities []Entity + for _, entity := range entities { + exists := false + for _, existingEntity := range graph.Entities { + if existingEntity.Name == entity.Name { + exists = true + break + } + } + + if !exists { + newEntities = append(newEntities, entity) + graph.Entities = append(graph.Entities, entity) + } + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return newEntities, nil +} + +// createRelations adds new relations to the graph, skipping exact duplicates. +func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var newRelations []Relation + for _, relation := range relations { + exists := false + for _, existingRelation := range graph.Relations { + if existingRelation.From == relation.From && + existingRelation.To == relation.To && + existingRelation.RelationType == relation.RelationType { + exists = true + break + } + } + + if !exists { + newRelations = append(newRelations, relation) + graph.Relations = append(graph.Relations, relation) + } + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return newRelations, nil +} + +// addObservations appends new observations to existing entities. +func (k knowledgeBase) addObservations(observations []Observation) ([]Observation, error) { + graph, err := k.loadGraph() + if err != nil { + return nil, err + } + + var results []Observation + + for _, obs := range observations { + entityIndex := -1 + for i, entity := range graph.Entities { + if entity.Name == obs.EntityName { + entityIndex = i + break + } + } + + if entityIndex == -1 { + return nil, fmt.Errorf("entity with name %s not found", obs.EntityName) + } + + var newObservations []string + for _, content := range obs.Contents { + exists := slices.Contains(graph.Entities[entityIndex].Observations, content) + + if !exists { + newObservations = append(newObservations, content) + graph.Entities[entityIndex].Observations = append(graph.Entities[entityIndex].Observations, content) + } + } + + results = append(results, Observation{ + EntityName: obs.EntityName, + Contents: newObservations, + }) + } + + if err := k.saveGraph(graph); err != nil { + return nil, err + } + + return results, nil +} + +// deleteEntities removes entities and their associated relations. +func (k knowledgeBase) deleteEntities(entityNames []string) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + // Create map for quick lookup + entitiesToDelete := make(map[string]bool) + for _, name := range entityNames { + entitiesToDelete[name] = true + } + + // Filter entities + var filteredEntities []Entity + for _, entity := range graph.Entities { + if !entitiesToDelete[entity.Name] { + filteredEntities = append(filteredEntities, entity) + } + } + graph.Entities = filteredEntities + + // Filter relations + var filteredRelations []Relation + for _, relation := range graph.Relations { + if !entitiesToDelete[relation.From] && !entitiesToDelete[relation.To] { + filteredRelations = append(filteredRelations, relation) + } + } + graph.Relations = filteredRelations + + return k.saveGraph(graph) +} + +// deleteObservations removes specific observations from entities. +func (k knowledgeBase) deleteObservations(deletions []Observation) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + for _, deletion := range deletions { + for i, entity := range graph.Entities { + if entity.Name == deletion.EntityName { + // Create a map for quick lookup + observationsToDelete := make(map[string]bool) + for _, observation := range deletion.Observations { + observationsToDelete[observation] = true + } + + // Filter observations + var filteredObservations []string + for _, observation := range entity.Observations { + if !observationsToDelete[observation] { + filteredObservations = append(filteredObservations, observation) + } + } + + graph.Entities[i].Observations = filteredObservations + break + } + } + } + + return k.saveGraph(graph) +} + +// deleteRelations removes specific relations from the graph. +func (k knowledgeBase) deleteRelations(relations []Relation) error { + graph, err := k.loadGraph() + if err != nil { + return err + } + + var filteredRelations []Relation + for _, existingRelation := range graph.Relations { + shouldKeep := true + + for _, relationToDelete := range relations { + if existingRelation.From == relationToDelete.From && + existingRelation.To == relationToDelete.To && + existingRelation.RelationType == relationToDelete.RelationType { + shouldKeep = false + break + } + } + + if shouldKeep { + filteredRelations = append(filteredRelations, existingRelation) + } + } + + graph.Relations = filteredRelations + return k.saveGraph(graph) +} + +// readGraph returns the complete knowledge graph. +func (k knowledgeBase) readGraph() (KnowledgeGraph, error) { + return k.loadGraph() +} + +// searchNodes filters entities and relations matching the query string. +func (k knowledgeBase) searchNodes(query string) (KnowledgeGraph, error) { + graph, err := k.loadGraph() + if err != nil { + return KnowledgeGraph{}, err + } + + queryLower := strings.ToLower(query) + var filteredEntities []Entity + + // Filter entities + for _, entity := range graph.Entities { + if strings.Contains(strings.ToLower(entity.Name), queryLower) || + strings.Contains(strings.ToLower(entity.EntityType), queryLower) { + filteredEntities = append(filteredEntities, entity) + continue + } + + // Check observations + for _, observation := range entity.Observations { + if strings.Contains(strings.ToLower(observation), queryLower) { + filteredEntities = append(filteredEntities, entity) + break + } + } + } + + // Create map for quick entity lookup + filteredEntityNames := make(map[string]bool) + for _, entity := range filteredEntities { + filteredEntityNames[entity.Name] = true + } + + // Filter relations + var filteredRelations []Relation + for _, relation := range graph.Relations { + if filteredEntityNames[relation.From] && filteredEntityNames[relation.To] { + filteredRelations = append(filteredRelations, relation) + } + } + + return KnowledgeGraph{ + Entities: filteredEntities, + Relations: filteredRelations, + }, nil +} + +// openNodes returns entities with specified names and their interconnecting relations. +func (k knowledgeBase) openNodes(names []string) (KnowledgeGraph, error) { + graph, err := k.loadGraph() + if err != nil { + return KnowledgeGraph{}, err + } + + // Create map for quick name lookup + nameSet := make(map[string]bool) + for _, name := range names { + nameSet[name] = true + } + + // Filter entities + var filteredEntities []Entity + for _, entity := range graph.Entities { + if nameSet[entity.Name] { + filteredEntities = append(filteredEntities, entity) + } + } + + // Create map for quick entity lookup + filteredEntityNames := make(map[string]bool) + for _, entity := range filteredEntities { + filteredEntityNames[entity.Name] = true + } + + // Filter relations + var filteredRelations []Relation + for _, relation := range graph.Relations { + if filteredEntityNames[relation.From] && filteredEntityNames[relation.To] { + filteredRelations = append(filteredRelations, relation) + } + } + + return KnowledgeGraph{ + Entities: filteredEntities, + Relations: filteredRelations, + }, nil +} + +func (k knowledgeBase) CreateEntities(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[CreateEntitiesArgs]) (*mcp.CallToolResultFor[CreateEntitiesResult], error) { + var res mcp.CallToolResultFor[CreateEntitiesResult] + + entities, err := k.createEntities(params.Arguments.Entities) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + // I think marshalling the entities and pass it as a content should not be necessary, but as for now, it looks like + // the StructuredContent is not being unmarshalled in CallToolResultFor. + entitiesJSON, err := json.Marshal(entities) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(entitiesJSON)}, + } + + res.StructuredContent = CreateEntitiesResult{ + Entities: entities, + } + + return &res, nil +} + +func (k knowledgeBase) CreateRelations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[CreateRelationsArgs]) (*mcp.CallToolResultFor[CreateRelationsResult], error) { + var res mcp.CallToolResultFor[CreateRelationsResult] + + relations, err := k.createRelations(params.Arguments.Relations) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + relationsJSON, err := json.Marshal(relations) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(relationsJSON)}, + } + + res.StructuredContent = CreateRelationsResult{ + Relations: relations, + } + + return &res, nil +} + +func (k knowledgeBase) AddObservations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[AddObservationsArgs]) (*mcp.CallToolResultFor[AddObservationsResult], error) { + var res mcp.CallToolResultFor[AddObservationsResult] + + observations, err := k.addObservations(params.Arguments.Observations) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + observationsJSON, err := json.Marshal(observations) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(observationsJSON)}, + } + + res.StructuredContent = AddObservationsResult{ + Observations: observations, + } + + return &res, nil +} + +func (k knowledgeBase) DeleteEntities(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteEntitiesArgs]) (*mcp.CallToolResultFor[struct{}], error) { + var res mcp.CallToolResultFor[struct{}] + + err := k.deleteEntities(params.Arguments.EntityNames) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Entities deleted successfully"}, + } + + return &res, nil +} + +func (k knowledgeBase) DeleteObservations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteObservationsArgs]) (*mcp.CallToolResultFor[struct{}], error) { + var res mcp.CallToolResultFor[struct{}] + + err := k.deleteObservations(params.Arguments.Deletions) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Observations deleted successfully"}, + } + + return &res, nil +} + +func (k knowledgeBase) DeleteRelations(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[DeleteRelationsArgs]) (*mcp.CallToolResultFor[struct{}], error) { + var res mcp.CallToolResultFor[struct{}] + + err := k.deleteRelations(params.Arguments.Relations) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + res.Content = []mcp.Content{ + &mcp.TextContent{Text: "Relations deleted successfully"}, + } + + return &res, nil +} + +func (k knowledgeBase) ReadGraph(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[struct{}]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { + var res mcp.CallToolResultFor[KnowledgeGraph] + + graph, err := k.readGraph() + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + graphJSON, err := json.Marshal(graph) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(graphJSON)}, + } + + res.StructuredContent = graph + return &res, nil +} + +func (k knowledgeBase) SearchNodes(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[SearchNodesArgs]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { + var res mcp.CallToolResultFor[KnowledgeGraph] + + graph, err := k.searchNodes(params.Arguments.Query) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + graphJSON, err := json.Marshal(graph) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(graphJSON)}, + } + + res.StructuredContent = graph + return &res, nil +} + +func (k knowledgeBase) OpenNodes(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[OpenNodesArgs]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { + var res mcp.CallToolResultFor[KnowledgeGraph] + + graph, err := k.openNodes(params.Arguments.Names) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + + graphJSON, err := json.Marshal(graph) + if err != nil { + res.IsError = true + res.Content = []mcp.Content{ + &mcp.TextContent{Text: err.Error()}, + } + return &res, nil + } + res.Content = []mcp.Content{ + &mcp.TextContent{Text: string(graphJSON)}, + } + + res.StructuredContent = graph + return &res, nil +} diff --git a/examples/memory/kb_test.go b/examples/memory/kb_test.go new file mode 100644 index 00000000..ddd8bcc5 --- /dev/null +++ b/examples/memory/kb_test.go @@ -0,0 +1,717 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// getStoreFactories provides test factories for both storage implementations. +func getStoreFactories() map[string]func(t *testing.T) store { + return map[string]func(t *testing.T) store{ + "file": func(t *testing.T) store { + tempDir, err := os.MkdirTemp("", "kb-test-file-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + t.Cleanup(func() { os.RemoveAll(tempDir) }) + return &fileStore{path: filepath.Join(tempDir, "test-memory.json")} + }, + "memory": func(t *testing.T) store { + return &memoryStore{} + }, + } +} + +// TestKnowledgeBaseOperations verifies CRUD operations work correctly. +func TestKnowledgeBaseOperations(t *testing.T) { + factories := getStoreFactories() + + for name, factory := range factories { + t.Run(name, func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Verify empty graph loads correctly + graph, err := kb.loadGraph() + if err != nil { + t.Fatalf("failed to load empty graph: %v", err) + } + if len(graph.Entities) != 0 || len(graph.Relations) != 0 { + t.Errorf("expected empty graph, got %+v", graph) + } + + // Create and verify entities + testEntities := []Entity{ + { + Name: "Alice", + EntityType: "Person", + Observations: []string{"Likes coffee"}, + }, + { + Name: "Bob", + EntityType: "Person", + Observations: []string{"Likes tea"}, + }, + } + + createdEntities, err := kb.createEntities(testEntities) + if err != nil { + t.Fatalf("failed to create entities: %v", err) + } + if len(createdEntities) != 2 { + t.Errorf("expected 2 created entities, got %d", len(createdEntities)) + } + + // Verify entities persist + graph, err = kb.readGraph() + if err != nil { + t.Fatalf("failed to read graph: %v", err) + } + if len(graph.Entities) != 2 { + t.Errorf("expected 2 entities, got %d", len(graph.Entities)) + } + + // Create and verify relations + testRelations := []Relation{ + { + From: "Alice", + To: "Bob", + RelationType: "friend", + }, + } + + createdRelations, err := kb.createRelations(testRelations) + if err != nil { + t.Fatalf("failed to create relations: %v", err) + } + if len(createdRelations) != 1 { + t.Errorf("expected 1 created relation, got %d", len(createdRelations)) + } + + // Add observations to entities + testObservations := []Observation{ + { + EntityName: "Alice", + Contents: []string{"Works as developer", "Lives in New York"}, + }, + } + + addedObservations, err := kb.addObservations(testObservations) + if err != nil { + t.Fatalf("failed to add observations: %v", err) + } + if len(addedObservations) != 1 || len(addedObservations[0].Contents) != 2 { + t.Errorf("expected 1 observation with 2 contents, got %+v", addedObservations) + } + + // Search nodes by content + searchResult, err := kb.searchNodes("developer") + if err != nil { + t.Fatalf("failed to search nodes: %v", err) + } + if len(searchResult.Entities) != 1 || searchResult.Entities[0].Name != "Alice" { + t.Errorf("expected to find Alice when searching for 'developer', got %+v", searchResult) + } + + // Retrieve specific nodes + openResult, err := kb.openNodes([]string{"Bob"}) + if err != nil { + t.Fatalf("failed to open nodes: %v", err) + } + if len(openResult.Entities) != 1 || openResult.Entities[0].Name != "Bob" { + t.Errorf("expected to find Bob when opening 'Bob', got %+v", openResult) + } + + // Remove specific observations + deleteObs := []Observation{ + { + EntityName: "Alice", + Observations: []string{"Works as developer"}, + }, + } + err = kb.deleteObservations(deleteObs) + if err != nil { + t.Fatalf("failed to delete observations: %v", err) + } + + // Confirm observation removal + graph, _ = kb.readGraph() + aliceFound := false + for _, e := range graph.Entities { + if e.Name == "Alice" { + aliceFound = true + for _, obs := range e.Observations { + if obs == "Works as developer" { + t.Errorf("observation 'Works as developer' should have been deleted") + } + } + } + } + if !aliceFound { + t.Errorf("entity 'Alice' not found after deleting observation") + } + + // Remove relations + err = kb.deleteRelations(testRelations) + if err != nil { + t.Fatalf("failed to delete relations: %v", err) + } + + // Confirm relation removal + graph, _ = kb.readGraph() + if len(graph.Relations) != 0 { + t.Errorf("expected 0 relations after deletion, got %d", len(graph.Relations)) + } + + // Remove entities + err = kb.deleteEntities([]string{"Alice"}) + if err != nil { + t.Fatalf("failed to delete entities: %v", err) + } + + // Confirm entity removal + graph, _ = kb.readGraph() + if len(graph.Entities) != 1 || graph.Entities[0].Name != "Bob" { + t.Errorf("expected only Bob to remain after deleting Alice, got %+v", graph.Entities) + } + }) + } +} + +// TestSaveAndLoadGraph ensures data persists correctly across save/load cycles. +func TestSaveAndLoadGraph(t *testing.T) { + factories := getStoreFactories() + + for name, factory := range factories { + t.Run(name, func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Setup test data + testGraph := KnowledgeGraph{ + Entities: []Entity{ + { + Name: "Charlie", + EntityType: "Person", + Observations: []string{"Likes hiking"}, + }, + }, + Relations: []Relation{ + { + From: "Charlie", + To: "Mountains", + RelationType: "enjoys", + }, + }, + } + + // Persist to storage + err := kb.saveGraph(testGraph) + if err != nil { + t.Fatalf("failed to save graph: %v", err) + } + + // Reload from storage + loadedGraph, err := kb.loadGraph() + if err != nil { + t.Fatalf("failed to load graph: %v", err) + } + + // Verify data integrity + if !reflect.DeepEqual(testGraph, loadedGraph) { + t.Errorf("loaded graph does not match saved graph.\nExpected: %+v\nGot: %+v", testGraph, loadedGraph) + } + + // Test malformed data handling + if fs, ok := s.(*fileStore); ok { + err := os.WriteFile(fs.path, []byte("invalid json"), 0600) + if err != nil { + t.Fatalf("failed to write invalid json: %v", err) + } + + _, err = kb.loadGraph() + if err == nil { + t.Errorf("expected error when loading invalid JSON, got nil") + } + } + }) + } +} + +// TestDuplicateEntitiesAndRelations verifies duplicate prevention logic. +func TestDuplicateEntitiesAndRelations(t *testing.T) { + factories := getStoreFactories() + + for name, factory := range factories { + t.Run(name, func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Setup initial state + initialEntities := []Entity{ + { + Name: "Dave", + EntityType: "Person", + Observations: []string{"Plays guitar"}, + }, + } + + _, err := kb.createEntities(initialEntities) + if err != nil { + t.Fatalf("failed to create initial entities: %v", err) + } + + // Attempt duplicate creation + duplicateEntities := []Entity{ + { + Name: "Dave", + EntityType: "Person", + Observations: []string{"Sings well"}, + }, + { + Name: "Eve", + EntityType: "Person", + Observations: []string{"Plays piano"}, + }, + } + + newEntities, err := kb.createEntities(duplicateEntities) + if err != nil { + t.Fatalf("failed when adding duplicate entities: %v", err) + } + + // Verify only new entities created + if len(newEntities) != 1 || newEntities[0].Name != "Eve" { + t.Errorf("expected only 'Eve' to be created, got %+v", newEntities) + } + + // Setup initial relation + initialRelation := []Relation{ + { + From: "Dave", + To: "Eve", + RelationType: "friend", + }, + } + + _, err = kb.createRelations(initialRelation) + if err != nil { + t.Fatalf("failed to create initial relation: %v", err) + } + + // Test relation deduplication + duplicateRelations := []Relation{ + { + From: "Dave", + To: "Eve", + RelationType: "friend", + }, + { + From: "Eve", + To: "Dave", + RelationType: "friend", + }, + } + + newRelations, err := kb.createRelations(duplicateRelations) + if err != nil { + t.Fatalf("failed when adding duplicate relations: %v", err) + } + + // Verify only new relations created + if len(newRelations) != 1 || newRelations[0].From != "Eve" || newRelations[0].To != "Dave" { + t.Errorf("expected only 'Eve->Dave' relation to be created, got %+v", newRelations) + } + }) + } +} + +// TestErrorHandling verifies proper error responses for invalid operations. +func TestErrorHandling(t *testing.T) { + t.Run("FileStoreWriteError", func(t *testing.T) { + // Test file write to invalid path + kb := knowledgeBase{ + s: &fileStore{path: filepath.Join("nonexistent", "directory", "file.json")}, + } + + testEntities := []Entity{ + {Name: "TestEntity"}, + } + + _, err := kb.createEntities(testEntities) + if err == nil { + t.Errorf("expected error when writing to non-existent directory, got nil") + } + }) + + factories := getStoreFactories() + for name, factory := range factories { + t.Run(fmt.Sprintf("AddObservationToNonExistentEntity_%s", name), func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Setup valid entity for comparison + _, err := kb.createEntities([]Entity{{Name: "RealEntity"}}) + if err != nil { + t.Fatalf("failed to create test entity: %v", err) + } + + // Test invalid entity reference + nonExistentObs := []Observation{ + { + EntityName: "NonExistentEntity", + Contents: []string{"This shouldn't work"}, + }, + } + + _, err = kb.addObservations(nonExistentObs) + if err == nil { + t.Errorf("expected error when adding observations to non-existent entity, got nil") + } + }) + } +} + +// TestFileFormatting verifies the JSON storage format structure. +func TestFileFormatting(t *testing.T) { + factories := getStoreFactories() + + for name, factory := range factories { + t.Run(name, func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Setup test entity + testEntities := []Entity{ + { + Name: "FileTest", + EntityType: "TestEntity", + Observations: []string{"Test observation"}, + }, + } + + _, err := kb.createEntities(testEntities) + if err != nil { + t.Fatalf("failed to create test entity: %v", err) + } + + // Extract raw storage data + data, err := s.Read() + if err != nil { + t.Fatalf("failed to read from store: %v", err) + } + + // Validate JSON format + var items []kbItem + err = json.Unmarshal(data, &items) + if err != nil { + t.Fatalf("failed to parse store data JSON: %v", err) + } + + // Check data structure + if len(items) != 1 { + t.Fatalf("expected 1 item in memory file, got %d", len(items)) + } + + item := items[0] + if item.Type != "entity" || + item.Name != "FileTest" || + item.EntityType != "TestEntity" || + len(item.Observations) != 1 || + item.Observations[0] != "Test observation" { + t.Errorf("store item format incorrect: %+v", item) + } + }) + } +} + +// TestMCPServerIntegration tests the knowledge base through MCP server layer. +func TestMCPServerIntegration(t *testing.T) { + factories := getStoreFactories() + + for name, factory := range factories { + t.Run(name, func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + // Create mock server session + ctx := context.Background() + serverSession := &mcp.ServerSession{} + + // Test CreateEntities through MCP + createEntitiesParams := &mcp.CallToolParamsFor[CreateEntitiesArgs]{ + Arguments: CreateEntitiesArgs{ + Entities: []Entity{ + { + Name: "TestPerson", + EntityType: "Person", + Observations: []string{"Likes testing"}, + }, + }, + }, + } + + createResult, err := kb.CreateEntities(ctx, serverSession, createEntitiesParams) + if err != nil { + t.Fatalf("MCP CreateEntities failed: %v", err) + } + if createResult.IsError { + t.Fatalf("MCP CreateEntities returned error: %v", createResult.Content) + } + if len(createResult.StructuredContent.Entities) != 1 { + t.Errorf("expected 1 entity created, got %d", len(createResult.StructuredContent.Entities)) + } + + // Test ReadGraph through MCP + readParams := &mcp.CallToolParamsFor[struct{}]{} + readResult, err := kb.ReadGraph(ctx, serverSession, readParams) + if err != nil { + t.Fatalf("MCP ReadGraph failed: %v", err) + } + if readResult.IsError { + t.Fatalf("MCP ReadGraph returned error: %v", readResult.Content) + } + if len(readResult.StructuredContent.Entities) != 1 { + t.Errorf("expected 1 entity in graph, got %d", len(readResult.StructuredContent.Entities)) + } + + // Test CreateRelations through MCP + createRelationsParams := &mcp.CallToolParamsFor[CreateRelationsArgs]{ + Arguments: CreateRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", + }, + }, + }, + } + + relationsResult, err := kb.CreateRelations(ctx, serverSession, createRelationsParams) + if err != nil { + t.Fatalf("MCP CreateRelations failed: %v", err) + } + if relationsResult.IsError { + t.Fatalf("MCP CreateRelations returned error: %v", relationsResult.Content) + } + if len(relationsResult.StructuredContent.Relations) != 1 { + t.Errorf("expected 1 relation created, got %d", len(relationsResult.StructuredContent.Relations)) + } + + // Test AddObservations through MCP + addObsParams := &mcp.CallToolParamsFor[AddObservationsArgs]{ + Arguments: AddObservationsArgs{ + Observations: []Observation{ + { + EntityName: "TestPerson", + Contents: []string{"Works remotely", "Drinks coffee"}, + }, + }, + }, + } + + obsResult, err := kb.AddObservations(ctx, serverSession, addObsParams) + if err != nil { + t.Fatalf("MCP AddObservations failed: %v", err) + } + if obsResult.IsError { + t.Fatalf("MCP AddObservations returned error: %v", obsResult.Content) + } + if len(obsResult.StructuredContent.Observations) != 1 { + t.Errorf("expected 1 observation result, got %d", len(obsResult.StructuredContent.Observations)) + } + + // Test SearchNodes through MCP + searchParams := &mcp.CallToolParamsFor[SearchNodesArgs]{ + Arguments: SearchNodesArgs{ + Query: "coffee", + }, + } + + searchResult, err := kb.SearchNodes(ctx, serverSession, searchParams) + if err != nil { + t.Fatalf("MCP SearchNodes failed: %v", err) + } + if searchResult.IsError { + t.Fatalf("MCP SearchNodes returned error: %v", searchResult.Content) + } + if len(searchResult.StructuredContent.Entities) != 1 { + t.Errorf("expected 1 entity from search, got %d", len(searchResult.StructuredContent.Entities)) + } + + // Test OpenNodes through MCP + openParams := &mcp.CallToolParamsFor[OpenNodesArgs]{ + Arguments: OpenNodesArgs{ + Names: []string{"TestPerson"}, + }, + } + + openResult, err := kb.OpenNodes(ctx, serverSession, openParams) + if err != nil { + t.Fatalf("MCP OpenNodes failed: %v", err) + } + if openResult.IsError { + t.Fatalf("MCP OpenNodes returned error: %v", openResult.Content) + } + if len(openResult.StructuredContent.Entities) != 1 { + t.Errorf("expected 1 entity from open, got %d", len(openResult.StructuredContent.Entities)) + } + + // Test DeleteObservations through MCP + deleteObsParams := &mcp.CallToolParamsFor[DeleteObservationsArgs]{ + Arguments: DeleteObservationsArgs{ + Deletions: []Observation{ + { + EntityName: "TestPerson", + Observations: []string{"Works remotely"}, + }, + }, + }, + } + + deleteObsResult, err := kb.DeleteObservations(ctx, serverSession, deleteObsParams) + if err != nil { + t.Fatalf("MCP DeleteObservations failed: %v", err) + } + if deleteObsResult.IsError { + t.Fatalf("MCP DeleteObservations returned error: %v", deleteObsResult.Content) + } + + // Test DeleteRelations through MCP + deleteRelParams := &mcp.CallToolParamsFor[DeleteRelationsArgs]{ + Arguments: DeleteRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", + }, + }, + }, + } + + deleteRelResult, err := kb.DeleteRelations(ctx, serverSession, deleteRelParams) + if err != nil { + t.Fatalf("MCP DeleteRelations failed: %v", err) + } + if deleteRelResult.IsError { + t.Fatalf("MCP DeleteRelations returned error: %v", deleteRelResult.Content) + } + + // Test DeleteEntities through MCP + deleteEntParams := &mcp.CallToolParamsFor[DeleteEntitiesArgs]{ + Arguments: DeleteEntitiesArgs{ + EntityNames: []string{"TestPerson"}, + }, + } + + deleteEntResult, err := kb.DeleteEntities(ctx, serverSession, deleteEntParams) + if err != nil { + t.Fatalf("MCP DeleteEntities failed: %v", err) + } + if deleteEntResult.IsError { + t.Fatalf("MCP DeleteEntities returned error: %v", deleteEntResult.Content) + } + + // Verify final state + finalRead, err := kb.ReadGraph(ctx, serverSession, readParams) + if err != nil { + t.Fatalf("Final MCP ReadGraph failed: %v", err) + } + if len(finalRead.StructuredContent.Entities) != 0 { + t.Errorf("expected empty graph after deletion, got %d entities", len(finalRead.StructuredContent.Entities)) + } + }) + } +} + +// TestMCPErrorHandling tests error scenarios through MCP layer. +func TestMCPErrorHandling(t *testing.T) { + factories := getStoreFactories() + + for name, factory := range factories { + t.Run(name, func(t *testing.T) { + s := factory(t) + kb := knowledgeBase{s: s} + + ctx := context.Background() + serverSession := &mcp.ServerSession{} + + // Test adding observations to non-existent entity + addObsParams := &mcp.CallToolParamsFor[AddObservationsArgs]{ + Arguments: AddObservationsArgs{ + Observations: []Observation{ + { + EntityName: "NonExistentEntity", + Contents: []string{"This should fail"}, + }, + }, + }, + } + + obsResult, err := kb.AddObservations(ctx, serverSession, addObsParams) + if err != nil { + t.Fatalf("MCP AddObservations call failed: %v", err) + } + if !obsResult.IsError { + t.Errorf("expected MCP AddObservations to return error for non-existent entity") + } + if len(obsResult.Content) == 0 { + t.Errorf("expected error content in MCP response") + } + }) + } +} + +// TestMCPResponseFormat verifies MCP response format consistency. +func TestMCPResponseFormat(t *testing.T) { + s := &memoryStore{} + kb := knowledgeBase{s: s} + + ctx := context.Background() + serverSession := &mcp.ServerSession{} + + // Test CreateEntities response format + createParams := &mcp.CallToolParamsFor[CreateEntitiesArgs]{ + Arguments: CreateEntitiesArgs{ + Entities: []Entity{ + {Name: "FormatTest", EntityType: "Test"}, + }, + }, + } + + result, err := kb.CreateEntities(ctx, serverSession, createParams) + if err != nil { + t.Fatalf("CreateEntities failed: %v", err) + } + + // Verify response has both Content and StructuredContent + if len(result.Content) == 0 { + t.Errorf("expected Content field to be populated") + } + if len(result.StructuredContent.Entities) == 0 { + t.Errorf("expected StructuredContent.Entities to be populated") + } + + // Verify Content contains valid JSON + if textContent, ok := result.Content[0].(*mcp.TextContent); ok { + var entities []Entity + if err := json.Unmarshal([]byte(textContent.Text), &entities); err != nil { + t.Errorf("Content field should contain valid JSON: %v", err) + } + } else { + t.Errorf("expected Content[0] to be TextContent") + } +} diff --git a/examples/memory/main.go b/examples/memory/main.go new file mode 100644 index 00000000..b0a6b8f0 --- /dev/null +++ b/examples/memory/main.go @@ -0,0 +1,162 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package main + +import ( + "context" + "flag" + "log" + "net/http" + "os" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") + memoryFilePath = flag.String("memory", "", "if set, persist the knowledge base to this file; otherwise, it will be stored in memory and lost on exit") +) + +// HiArgs defines arguments for the greeting tool. +type HiArgs struct { + Name string `json:"name"` +} + +// Entity represents a knowledge graph node with observations. +type Entity struct { + Name string `json:"name"` + EntityType string `json:"entityType"` + Observations []string `json:"observations"` +} + +// Relation represents a directed edge between two entities. +type Relation struct { + From string `json:"from"` + To string `json:"to"` + RelationType string `json:"relationType"` +} + +// Observation contains facts about an entity. +type Observation struct { + EntityName string `json:"entityName"` + Contents []string `json:"contents"` + + Observations []string `json:"observations,omitempty"` // Used for deletion operations +} + +// KnowledgeGraph represents the complete graph structure. +type KnowledgeGraph struct { + Entities []Entity `json:"entities"` + Relations []Relation `json:"relations"` +} + +// CreateEntitiesArgs defines the create entities tool parameters. +type CreateEntitiesArgs struct { + Entities []Entity `json:"entities"` +} + +// CreateEntitiesResult returns newly created entities. +type CreateEntitiesResult struct { + Entities []Entity `json:"entities"` +} + +// CreateRelationsArgs defines the create relations tool parameters. +type CreateRelationsArgs struct { + Relations []Relation `json:"relations"` +} + +// CreateRelationsResult returns newly created relations. +type CreateRelationsResult struct { + Relations []Relation `json:"relations"` +} + +// AddObservationsArgs defines the add observations tool parameters. +type AddObservationsArgs struct { + Observations []Observation `json:"observations"` +} + +// AddObservationsResult returns newly added observations. +type AddObservationsResult struct { + Observations []Observation `json:"observations"` +} + +// DeleteEntitiesArgs defines the delete entities tool parameters. +type DeleteEntitiesArgs struct { + EntityNames []string `json:"entityNames"` +} + +// DeleteObservationsArgs defines the delete observations tool parameters. +type DeleteObservationsArgs struct { + Deletions []Observation `json:"deletions"` +} + +// DeleteRelationsArgs defines the delete relations tool parameters. +type DeleteRelationsArgs struct { + Relations []Relation `json:"relations"` +} + +// SearchNodesArgs defines the search nodes tool parameters. +type SearchNodesArgs struct { + Query string `json:"query"` +} + +// OpenNodesArgs defines the open nodes tool parameters. +type OpenNodesArgs struct { + Names []string `json:"names"` +} + +func main() { + flag.Parse() + + // Initialize storage backend + var kbStore store + kbStore = &memoryStore{} + if *memoryFilePath != "" { + kbStore = &fileStore{path: *memoryFilePath} + } + kb := knowledgeBase{s: kbStore} + + // Setup MCP server with knowledge base tools + server := mcp.NewServer("memory", "v0.0.1", nil) + server.AddTools(mcp.NewServerTool("create_entities", "Create multiple new entities in the knowledge graph", kb.CreateEntities, mcp.Input( + mcp.Property("entities", mcp.Description("Entities to create")), + ))) + server.AddTools(mcp.NewServerTool("create_relations", "Create multiple new relations between entities", kb.CreateRelations, mcp.Input( + mcp.Property("relations", mcp.Description("Relations to create")), + ))) + server.AddTools(mcp.NewServerTool("add_observations", "Add new observations to existing entities", kb.AddObservations, mcp.Input( + mcp.Property("observations", mcp.Description("Observations to add")), + ))) + server.AddTools(mcp.NewServerTool("delete_entities", "Remove entities and their relations", kb.DeleteEntities, mcp.Input( + mcp.Property("entityNames", mcp.Description("Names of entities to delete")), + ))) + server.AddTools(mcp.NewServerTool("delete_observations", "Remove specific observations from entities", kb.DeleteObservations, mcp.Input( + mcp.Property("deletions", mcp.Description("Observations to delete")), + ))) + server.AddTools(mcp.NewServerTool("delete_relations", "Remove specific relations from the graph", kb.DeleteRelations, mcp.Input( + mcp.Property("relations", mcp.Description("Relations to delete")), + ))) + server.AddTools(mcp.NewServerTool("read_graph", "Read the entire knowledge graph", kb.ReadGraph)) + server.AddTools(mcp.NewServerTool("search_nodes", "Search for nodes based on query", kb.SearchNodes, mcp.Input( + mcp.Property("query", mcp.Description("Query string")), + ))) + server.AddTools(mcp.NewServerTool("open_nodes", "Retrieve specific nodes by name", kb.OpenNodes, mcp.Input( + mcp.Property("names", mcp.Description("Names of nodes to open")), + ))) + + // Start server with appropriate transport + if *httpAddr != "" { + handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { + return server + }, nil) + log.Printf("MCP handler listening at %s", *httpAddr) + http.ListenAndServe(*httpAddr, handler) + } else { + t := mcp.NewLoggingTransport(mcp.NewStdioTransport(), os.Stderr) + if err := server.Run(context.Background(), t); err != nil { + log.Printf("Server failed: %v", err) + } + } +} From 1cae4498a2074224d4c32eb88ed873e451ced103 Mon Sep 17 00:00:00 2001 From: MegaGrindStone Date: Tue, 8 Jul 2025 15:46:12 +0700 Subject: [PATCH 2/7] fix: use new API schema in memory example --- examples/memory/kb.go | 62 +++--------------------------- examples/memory/kb_test.go | 8 ++-- examples/memory/main.go | 77 ++++++++++++++++++++++---------------- 3 files changed, 54 insertions(+), 93 deletions(-) diff --git a/examples/memory/kb.go b/examples/memory/kb.go index b6a8c4dd..de5a0889 100644 --- a/examples/memory/kb.go +++ b/examples/memory/kb.go @@ -464,18 +464,8 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, ss *mcp.ServerSession return &res, nil } - // I think marshalling the entities and pass it as a content should not be necessary, but as for now, it looks like - // the StructuredContent is not being unmarshalled in CallToolResultFor. - entitiesJSON, err := json.Marshal(entities) - if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil - } res.Content = []mcp.Content{ - &mcp.TextContent{Text: string(entitiesJSON)}, + &mcp.TextContent{Text: "Entities created successfully"}, } res.StructuredContent = CreateEntitiesResult{ @@ -497,16 +487,8 @@ func (k knowledgeBase) CreateRelations(ctx context.Context, ss *mcp.ServerSessio return &res, nil } - relationsJSON, err := json.Marshal(relations) - if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil - } res.Content = []mcp.Content{ - &mcp.TextContent{Text: string(relationsJSON)}, + &mcp.TextContent{Text: "Relations created successfully"}, } res.StructuredContent = CreateRelationsResult{ @@ -528,16 +510,8 @@ func (k knowledgeBase) AddObservations(ctx context.Context, ss *mcp.ServerSessio return &res, nil } - observationsJSON, err := json.Marshal(observations) - if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil - } res.Content = []mcp.Content{ - &mcp.TextContent{Text: string(observationsJSON)}, + &mcp.TextContent{Text: "Observations added successfully"}, } res.StructuredContent = AddObservationsResult{ @@ -616,16 +590,8 @@ func (k knowledgeBase) ReadGraph(ctx context.Context, ss *mcp.ServerSession, par return &res, nil } - graphJSON, err := json.Marshal(graph) - if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil - } res.Content = []mcp.Content{ - &mcp.TextContent{Text: string(graphJSON)}, + &mcp.TextContent{Text: "Graph read successfully"}, } res.StructuredContent = graph @@ -644,16 +610,8 @@ func (k knowledgeBase) SearchNodes(ctx context.Context, ss *mcp.ServerSession, p return &res, nil } - graphJSON, err := json.Marshal(graph) - if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil - } res.Content = []mcp.Content{ - &mcp.TextContent{Text: string(graphJSON)}, + &mcp.TextContent{Text: "Nodes searched successfully"}, } res.StructuredContent = graph @@ -672,16 +630,8 @@ func (k knowledgeBase) OpenNodes(ctx context.Context, ss *mcp.ServerSession, par return &res, nil } - graphJSON, err := json.Marshal(graph) - if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil - } res.Content = []mcp.Content{ - &mcp.TextContent{Text: string(graphJSON)}, + &mcp.TextContent{Text: "Nodes opened successfully"}, } res.StructuredContent = graph diff --git a/examples/memory/kb_test.go b/examples/memory/kb_test.go index ddd8bcc5..a0fe53d4 100644 --- a/examples/memory/kb_test.go +++ b/examples/memory/kb_test.go @@ -705,11 +705,11 @@ func TestMCPResponseFormat(t *testing.T) { t.Errorf("expected StructuredContent.Entities to be populated") } - // Verify Content contains valid JSON + // Verify Content contains simple success message if textContent, ok := result.Content[0].(*mcp.TextContent); ok { - var entities []Entity - if err := json.Unmarshal([]byte(textContent.Text), &entities); err != nil { - t.Errorf("Content field should contain valid JSON: %v", err) + expectedMessage := "Entities created successfully" + if textContent.Text != expectedMessage { + t.Errorf("expected Content field to contain '%s', got '%s'", expectedMessage, textContent.Text) } } else { t.Errorf("expected Content[0] to be TextContent") diff --git a/examples/memory/main.go b/examples/memory/main.go index b0a6b8f0..ba99f96f 100644 --- a/examples/memory/main.go +++ b/examples/memory/main.go @@ -54,7 +54,7 @@ type KnowledgeGraph struct { // CreateEntitiesArgs defines the create entities tool parameters. type CreateEntitiesArgs struct { - Entities []Entity `json:"entities"` + Entities []Entity `json:"entities" mcp:"entities to create"` } // CreateEntitiesResult returns newly created entities. @@ -64,7 +64,7 @@ type CreateEntitiesResult struct { // CreateRelationsArgs defines the create relations tool parameters. type CreateRelationsArgs struct { - Relations []Relation `json:"relations"` + Relations []Relation `json:"relations" mcp:"relations to create"` } // CreateRelationsResult returns newly created relations. @@ -74,7 +74,7 @@ type CreateRelationsResult struct { // AddObservationsArgs defines the add observations tool parameters. type AddObservationsArgs struct { - Observations []Observation `json:"observations"` + Observations []Observation `json:"observations" mcp:"observations to add"` } // AddObservationsResult returns newly added observations. @@ -84,27 +84,27 @@ type AddObservationsResult struct { // DeleteEntitiesArgs defines the delete entities tool parameters. type DeleteEntitiesArgs struct { - EntityNames []string `json:"entityNames"` + EntityNames []string `json:"entityNames" mcp:"entities to delete"` } // DeleteObservationsArgs defines the delete observations tool parameters. type DeleteObservationsArgs struct { - Deletions []Observation `json:"deletions"` + Deletions []Observation `json:"deletions" mcp:"obeservations to delete"` } // DeleteRelationsArgs defines the delete relations tool parameters. type DeleteRelationsArgs struct { - Relations []Relation `json:"relations"` + Relations []Relation `json:"relations" mcp:"relations to delete"` } // SearchNodesArgs defines the search nodes tool parameters. type SearchNodesArgs struct { - Query string `json:"query"` + Query string `json:"query" mcp:"query string"` } // OpenNodesArgs defines the open nodes tool parameters. type OpenNodesArgs struct { - Names []string `json:"names"` + Names []string `json:"names" mcp:"names of nodes to open"` } func main() { @@ -120,31 +120,42 @@ func main() { // Setup MCP server with knowledge base tools server := mcp.NewServer("memory", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("create_entities", "Create multiple new entities in the knowledge graph", kb.CreateEntities, mcp.Input( - mcp.Property("entities", mcp.Description("Entities to create")), - ))) - server.AddTools(mcp.NewServerTool("create_relations", "Create multiple new relations between entities", kb.CreateRelations, mcp.Input( - mcp.Property("relations", mcp.Description("Relations to create")), - ))) - server.AddTools(mcp.NewServerTool("add_observations", "Add new observations to existing entities", kb.AddObservations, mcp.Input( - mcp.Property("observations", mcp.Description("Observations to add")), - ))) - server.AddTools(mcp.NewServerTool("delete_entities", "Remove entities and their relations", kb.DeleteEntities, mcp.Input( - mcp.Property("entityNames", mcp.Description("Names of entities to delete")), - ))) - server.AddTools(mcp.NewServerTool("delete_observations", "Remove specific observations from entities", kb.DeleteObservations, mcp.Input( - mcp.Property("deletions", mcp.Description("Observations to delete")), - ))) - server.AddTools(mcp.NewServerTool("delete_relations", "Remove specific relations from the graph", kb.DeleteRelations, mcp.Input( - mcp.Property("relations", mcp.Description("Relations to delete")), - ))) - server.AddTools(mcp.NewServerTool("read_graph", "Read the entire knowledge graph", kb.ReadGraph)) - server.AddTools(mcp.NewServerTool("search_nodes", "Search for nodes based on query", kb.SearchNodes, mcp.Input( - mcp.Property("query", mcp.Description("Query string")), - ))) - server.AddTools(mcp.NewServerTool("open_nodes", "Retrieve specific nodes by name", kb.OpenNodes, mcp.Input( - mcp.Property("names", mcp.Description("Names of nodes to open")), - ))) + mcp.AddTool(server, &mcp.Tool{ + Name: "create_entities", + Description: "Create multiple new entities in the knowledge graph", + }, kb.CreateEntities) + mcp.AddTool(server, &mcp.Tool{ + Name: "create_relations", + Description: "Create multiple new relations between entities", + }, kb.CreateRelations) + mcp.AddTool(server, &mcp.Tool{ + Name: "add_observations", + Description: "Add new observations to existing entities", + }, kb.AddObservations) + mcp.AddTool(server, &mcp.Tool{ + Name: "delete_entities", + Description: "Remove entities and their relations", + }, kb.DeleteEntities) + mcp.AddTool(server, &mcp.Tool{ + Name: "delete_observations", + Description: "Remove specific observations from entities", + }, kb.DeleteObservations) + mcp.AddTool(server, &mcp.Tool{ + Name: "delete_relations", + Description: "Remove specific relations from the graph", + }, kb.DeleteRelations) + mcp.AddTool(server, &mcp.Tool{ + Name: "read_graph", + Description: "Read the entire knowledge graph", + }, kb.ReadGraph) + mcp.AddTool(server, &mcp.Tool{ + Name: "search_nodes", + Description: "Search for nodes based on query", + }, kb.SearchNodes) + mcp.AddTool(server, &mcp.Tool{ + Name: "open_nodes", + Description: "Retrieve specific nodes by name", + }, kb.OpenNodes) // Start server with appropriate transport if *httpAddr != "" { From 0c74f309bcb3ffe0a011967c8221133727731352 Mon Sep 17 00:00:00 2001 From: MegaGrindStone Date: Tue, 8 Jul 2025 16:28:19 +0700 Subject: [PATCH 3/7] refactor: address PR feedback in memory example --- examples/memory/kb.go | 132 ++++++++++++++---------------------------- 1 file changed, 44 insertions(+), 88 deletions(-) diff --git a/examples/memory/kb.go b/examples/memory/kb.go index de5a0889..1d07c93e 100644 --- a/examples/memory/kb.go +++ b/examples/memory/kb.go @@ -47,7 +47,7 @@ func (fs *fileStore) Read() ([]byte, error) { data, err := os.ReadFile(fs.path) if err != nil { if os.IsNotExist(err) { - return []byte{}, nil + return nil, nil } return nil, fmt.Errorf("failed to read file %s: %w", fs.path, err) } @@ -98,10 +98,7 @@ func (k knowledgeBase) loadGraph() (KnowledgeGraph, error) { return KnowledgeGraph{}, fmt.Errorf("failed to unmarshal from store: %w", err) } - graph := KnowledgeGraph{ - Entities: []Entity{}, - Relations: []Relation{}, - } + graph := KnowledgeGraph{} for _, item := range items { switch item.Type { @@ -157,6 +154,7 @@ func (k knowledgeBase) saveGraph(graph KnowledgeGraph) error { } // createEntities adds new entities to the graph, skipping duplicates by name. +// Returns the new entities that were actually added. func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) { graph, err := k.loadGraph() if err != nil { @@ -165,15 +163,7 @@ func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) { var newEntities []Entity for _, entity := range entities { - exists := false - for _, existingEntity := range graph.Entities { - if existingEntity.Name == entity.Name { - exists = true - break - } - } - - if !exists { + if !slices.ContainsFunc(graph.Entities, func(e Entity) bool { return e.Name == entity.Name }) { newEntities = append(newEntities, entity) graph.Entities = append(graph.Entities, entity) } @@ -187,6 +177,7 @@ func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) { } // createRelations adds new relations to the graph, skipping exact duplicates. +// Returns the new relations that were actually added. func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) { graph, err := k.loadGraph() if err != nil { @@ -195,16 +186,11 @@ func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) var newRelations []Relation for _, relation := range relations { - exists := false - for _, existingRelation := range graph.Relations { - if existingRelation.From == relation.From && - existingRelation.To == relation.To && - existingRelation.RelationType == relation.RelationType { - exists = true - break - } - } - + exists := slices.ContainsFunc(graph.Relations, func(r Relation) bool { + return r.From == relation.From && + r.To == relation.To && + r.RelationType == relation.RelationType + }) if !exists { newRelations = append(newRelations, relation) graph.Relations = append(graph.Relations, relation) @@ -219,6 +205,7 @@ func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) } // addObservations appends new observations to existing entities. +// Returns the new observations that were actually added. func (k knowledgeBase) addObservations(observations []Observation) ([]Observation, error) { graph, err := k.loadGraph() if err != nil { @@ -228,23 +215,14 @@ func (k knowledgeBase) addObservations(observations []Observation) ([]Observatio var results []Observation for _, obs := range observations { - entityIndex := -1 - for i, entity := range graph.Entities { - if entity.Name == obs.EntityName { - entityIndex = i - break - } - } - + entityIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool { return e.Name == obs.EntityName }) if entityIndex == -1 { return nil, fmt.Errorf("entity with name %s not found", obs.EntityName) } var newObservations []string for _, content := range obs.Contents { - exists := slices.Contains(graph.Entities[entityIndex].Observations, content) - - if !exists { + if !slices.Contains(graph.Entities[entityIndex].Observations, content) { newObservations = append(newObservations, content) graph.Entities[entityIndex].Observations = append(graph.Entities[entityIndex].Observations, content) } @@ -276,23 +254,15 @@ func (k knowledgeBase) deleteEntities(entityNames []string) error { entitiesToDelete[name] = true } - // Filter entities - var filteredEntities []Entity - for _, entity := range graph.Entities { - if !entitiesToDelete[entity.Name] { - filteredEntities = append(filteredEntities, entity) - } - } - graph.Entities = filteredEntities + // Filter entities using slices.DeleteFunc + graph.Entities = slices.DeleteFunc(graph.Entities, func(entity Entity) bool { + return entitiesToDelete[entity.Name] + }) - // Filter relations - var filteredRelations []Relation - for _, relation := range graph.Relations { - if !entitiesToDelete[relation.From] && !entitiesToDelete[relation.To] { - filteredRelations = append(filteredRelations, relation) - } - } - graph.Relations = filteredRelations + // Filter relations using slices.DeleteFunc + graph.Relations = slices.DeleteFunc(graph.Relations, func(relation Relation) bool { + return entitiesToDelete[relation.From] || entitiesToDelete[relation.To] + }) return k.saveGraph(graph) } @@ -305,26 +275,23 @@ func (k knowledgeBase) deleteObservations(deletions []Observation) error { } for _, deletion := range deletions { - for i, entity := range graph.Entities { - if entity.Name == deletion.EntityName { - // Create a map for quick lookup - observationsToDelete := make(map[string]bool) - for _, observation := range deletion.Observations { - observationsToDelete[observation] = true - } - - // Filter observations - var filteredObservations []string - for _, observation := range entity.Observations { - if !observationsToDelete[observation] { - filteredObservations = append(filteredObservations, observation) - } - } - - graph.Entities[i].Observations = filteredObservations - break - } + entityIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool { + return e.Name == deletion.EntityName + }) + if entityIndex == -1 { + continue } + + // Create a map for quick lookup + observationsToDelete := make(map[string]bool) + for _, observation := range deletion.Observations { + observationsToDelete[observation] = true + } + + // Filter observations using slices.DeleteFunc + graph.Entities[entityIndex].Observations = slices.DeleteFunc(graph.Entities[entityIndex].Observations, func(observation string) bool { + return observationsToDelete[observation] + }) } return k.saveGraph(graph) @@ -337,25 +304,14 @@ func (k knowledgeBase) deleteRelations(relations []Relation) error { return err } - var filteredRelations []Relation - for _, existingRelation := range graph.Relations { - shouldKeep := true - - for _, relationToDelete := range relations { - if existingRelation.From == relationToDelete.From && + // Filter relations using slices.DeleteFunc and slices.ContainsFunc + graph.Relations = slices.DeleteFunc(graph.Relations, func(existingRelation Relation) bool { + return slices.ContainsFunc(relations, func(relationToDelete Relation) bool { + return existingRelation.From == relationToDelete.From && existingRelation.To == relationToDelete.To && - existingRelation.RelationType == relationToDelete.RelationType { - shouldKeep = false - break - } - } - - if shouldKeep { - filteredRelations = append(filteredRelations, existingRelation) - } - } - - graph.Relations = filteredRelations + existingRelation.RelationType == relationToDelete.RelationType + }) + }) return k.saveGraph(graph) } From 3d0fa949ee40d62ceb67f09fbfbeef1a0ecbc58b Mon Sep 17 00:00:00 2001 From: MegaGrindStone Date: Tue, 8 Jul 2025 16:38:35 +0700 Subject: [PATCH 4/7] refactor: remove readGraph proxy function in memory example --- examples/memory/kb.go | 6 +----- examples/memory/kb_test.go | 8 ++++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/examples/memory/kb.go b/examples/memory/kb.go index 1d07c93e..8d521300 100644 --- a/examples/memory/kb.go +++ b/examples/memory/kb.go @@ -315,10 +315,6 @@ func (k knowledgeBase) deleteRelations(relations []Relation) error { return k.saveGraph(graph) } -// readGraph returns the complete knowledge graph. -func (k knowledgeBase) readGraph() (KnowledgeGraph, error) { - return k.loadGraph() -} // searchNodes filters entities and relations matching the query string. func (k knowledgeBase) searchNodes(query string) (KnowledgeGraph, error) { @@ -537,7 +533,7 @@ func (k knowledgeBase) DeleteRelations(ctx context.Context, ss *mcp.ServerSessio func (k knowledgeBase) ReadGraph(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[struct{}]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { var res mcp.CallToolResultFor[KnowledgeGraph] - graph, err := k.readGraph() + graph, err := k.loadGraph() if err != nil { res.IsError = true res.Content = []mcp.Content{ diff --git a/examples/memory/kb_test.go b/examples/memory/kb_test.go index a0fe53d4..2a3aaefd 100644 --- a/examples/memory/kb_test.go +++ b/examples/memory/kb_test.go @@ -74,7 +74,7 @@ func TestKnowledgeBaseOperations(t *testing.T) { } // Verify entities persist - graph, err = kb.readGraph() + graph, err = kb.loadGraph() if err != nil { t.Fatalf("failed to read graph: %v", err) } @@ -146,7 +146,7 @@ func TestKnowledgeBaseOperations(t *testing.T) { } // Confirm observation removal - graph, _ = kb.readGraph() + graph, _ = kb.loadGraph() aliceFound := false for _, e := range graph.Entities { if e.Name == "Alice" { @@ -169,7 +169,7 @@ func TestKnowledgeBaseOperations(t *testing.T) { } // Confirm relation removal - graph, _ = kb.readGraph() + graph, _ = kb.loadGraph() if len(graph.Relations) != 0 { t.Errorf("expected 0 relations after deletion, got %d", len(graph.Relations)) } @@ -181,7 +181,7 @@ func TestKnowledgeBaseOperations(t *testing.T) { } // Confirm entity removal - graph, _ = kb.readGraph() + graph, _ = kb.loadGraph() if len(graph.Entities) != 1 || graph.Entities[0].Name != "Bob" { t.Errorf("expected only Bob to remain after deleting Alice, got %+v", graph.Entities) } From 4f69af6bd68012e6a03f07d0ac3b7121b2bd6644 Mon Sep 17 00:00:00 2001 From: MegaGrindStone Date: Tue, 8 Jul 2025 16:40:08 +0700 Subject: [PATCH 5/7] refactor: fix lint in memory example --- examples/memory/kb.go | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/memory/kb.go b/examples/memory/kb.go index 8d521300..3c68f73d 100644 --- a/examples/memory/kb.go +++ b/examples/memory/kb.go @@ -315,7 +315,6 @@ func (k knowledgeBase) deleteRelations(relations []Relation) error { return k.saveGraph(graph) } - // searchNodes filters entities and relations matching the query string. func (k knowledgeBase) searchNodes(query string) (KnowledgeGraph, error) { graph, err := k.loadGraph() From c0a5c7b3dd07f3d43eca6f50271947844f0ddb5b Mon Sep 17 00:00:00 2001 From: MegaGrindStone Date: Tue, 8 Jul 2025 22:48:55 +0700 Subject: [PATCH 6/7] refactor: address PR feedback in memory example --- examples/memory/kb.go | 88 +++++++++++++++++--------------------- examples/memory/kb_test.go | 84 +++++++++++++++--------------------- examples/memory/main.go | 28 ------------ 3 files changed, 75 insertions(+), 125 deletions(-) diff --git a/examples/memory/kb.go b/examples/memory/kb.go index 3c68f73d..a274f057 100644 --- a/examples/memory/kb.go +++ b/examples/memory/kb.go @@ -15,6 +15,34 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) +// Entity represents a knowledge graph node with observations. +type Entity struct { + Name string `json:"name"` + EntityType string `json:"entityType"` + Observations []string `json:"observations"` +} + +// Relation represents a directed edge between two entities. +type Relation struct { + From string `json:"from"` + To string `json:"to"` + RelationType string `json:"relationType"` +} + +// Observation contains facts about an entity. +type Observation struct { + EntityName string `json:"entityName"` + Contents []string `json:"contents"` + + Observations []string `json:"observations,omitempty"` // Used for deletion operations +} + +// KnowledgeGraph represents the complete graph structure. +type KnowledgeGraph struct { + Entities []Entity `json:"entities"` + Relations []Relation `json:"relations"` +} + // store provides persistence interface for knowledge base data. type store interface { Read() ([]byte, error) @@ -154,7 +182,7 @@ func (k knowledgeBase) saveGraph(graph KnowledgeGraph) error { } // createEntities adds new entities to the graph, skipping duplicates by name. -// Returns the new entities that were actually added. +// It returns the new entities that were actually added. func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) { graph, err := k.loadGraph() if err != nil { @@ -177,7 +205,7 @@ func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) { } // createRelations adds new relations to the graph, skipping exact duplicates. -// Returns the new relations that were actually added. +// It returns the new relations that were actually added. func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) { graph, err := k.loadGraph() if err != nil { @@ -205,7 +233,7 @@ func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) } // addObservations appends new observations to existing entities. -// Returns the new observations that were actually added. +// It returns the new observations that were actually added. func (k knowledgeBase) addObservations(observations []Observation) ([]Observation, error) { graph, err := k.loadGraph() if err != nil { @@ -408,11 +436,7 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, ss *mcp.ServerSession entities, err := k.createEntities(params.Arguments.Entities) if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil + return nil, err } res.Content = []mcp.Content{ @@ -431,11 +455,7 @@ func (k knowledgeBase) CreateRelations(ctx context.Context, ss *mcp.ServerSessio relations, err := k.createRelations(params.Arguments.Relations) if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil + return nil, err } res.Content = []mcp.Content{ @@ -454,11 +474,7 @@ func (k knowledgeBase) AddObservations(ctx context.Context, ss *mcp.ServerSessio observations, err := k.addObservations(params.Arguments.Observations) if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil + return nil, err } res.Content = []mcp.Content{ @@ -477,11 +493,7 @@ func (k knowledgeBase) DeleteEntities(ctx context.Context, ss *mcp.ServerSession err := k.deleteEntities(params.Arguments.EntityNames) if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil + return nil, err } res.Content = []mcp.Content{ @@ -496,11 +508,7 @@ func (k knowledgeBase) DeleteObservations(ctx context.Context, ss *mcp.ServerSes err := k.deleteObservations(params.Arguments.Deletions) if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil + return nil, err } res.Content = []mcp.Content{ @@ -515,11 +523,7 @@ func (k knowledgeBase) DeleteRelations(ctx context.Context, ss *mcp.ServerSessio err := k.deleteRelations(params.Arguments.Relations) if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil + return nil, err } res.Content = []mcp.Content{ @@ -534,11 +538,7 @@ func (k knowledgeBase) ReadGraph(ctx context.Context, ss *mcp.ServerSession, par graph, err := k.loadGraph() if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil + return nil, err } res.Content = []mcp.Content{ @@ -554,11 +554,7 @@ func (k knowledgeBase) SearchNodes(ctx context.Context, ss *mcp.ServerSession, p graph, err := k.searchNodes(params.Arguments.Query) if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil + return nil, err } res.Content = []mcp.Content{ @@ -574,11 +570,7 @@ func (k knowledgeBase) OpenNodes(ctx context.Context, ss *mcp.ServerSession, par graph, err := k.openNodes(params.Arguments.Names) if err != nil { - res.IsError = true - res.Content = []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - } - return &res, nil + return nil, err } res.Content = []mcp.Content{ diff --git a/examples/memory/kb_test.go b/examples/memory/kb_test.go index 2a3aaefd..57e8506c 100644 --- a/examples/memory/kb_test.go +++ b/examples/memory/kb_test.go @@ -11,13 +11,15 @@ import ( "os" "path/filepath" "reflect" + "slices" + "strings" "testing" "github.com/modelcontextprotocol/go-sdk/mcp" ) -// getStoreFactories provides test factories for both storage implementations. -func getStoreFactories() map[string]func(t *testing.T) store { +// stores provides test factories for both storage implementations. +func stores() map[string]func(t *testing.T) store { return map[string]func(t *testing.T) store{ "file": func(t *testing.T) store { tempDir, err := os.MkdirTemp("", "kb-test-file-*") @@ -35,11 +37,9 @@ func getStoreFactories() map[string]func(t *testing.T) store { // TestKnowledgeBaseOperations verifies CRUD operations work correctly. func TestKnowledgeBaseOperations(t *testing.T) { - factories := getStoreFactories() - - for name, factory := range factories { + for name, newStore := range stores() { t.Run(name, func(t *testing.T) { - s := factory(t) + s := newStore(t) kb := knowledgeBase{s: s} // Verify empty graph loads correctly @@ -147,19 +147,16 @@ func TestKnowledgeBaseOperations(t *testing.T) { // Confirm observation removal graph, _ = kb.loadGraph() - aliceFound := false - for _, e := range graph.Entities { - if e.Name == "Alice" { - aliceFound = true - for _, obs := range e.Observations { - if obs == "Works as developer" { - t.Errorf("observation 'Works as developer' should have been deleted") - } - } - } - } - if !aliceFound { + aliceIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool { + return e.Name == "Alice" + }) + if aliceIndex == -1 { t.Errorf("entity 'Alice' not found after deleting observation") + } else { + alice := graph.Entities[aliceIndex] + if slices.Contains(alice.Observations, "Works as developer") { + t.Errorf("observation 'Works as developer' should have been deleted") + } } // Remove relations @@ -191,11 +188,9 @@ func TestKnowledgeBaseOperations(t *testing.T) { // TestSaveAndLoadGraph ensures data persists correctly across save/load cycles. func TestSaveAndLoadGraph(t *testing.T) { - factories := getStoreFactories() - - for name, factory := range factories { + for name, newStore := range stores() { t.Run(name, func(t *testing.T) { - s := factory(t) + s := newStore(t) kb := knowledgeBase{s: s} // Setup test data @@ -251,11 +246,9 @@ func TestSaveAndLoadGraph(t *testing.T) { // TestDuplicateEntitiesAndRelations verifies duplicate prevention logic. func TestDuplicateEntitiesAndRelations(t *testing.T) { - factories := getStoreFactories() - - for name, factory := range factories { + for name, newStore := range stores() { t.Run(name, func(t *testing.T) { - s := factory(t) + s := newStore(t) kb := knowledgeBase{s: s} // Setup initial state @@ -355,10 +348,9 @@ func TestErrorHandling(t *testing.T) { } }) - factories := getStoreFactories() - for name, factory := range factories { + for name, newStore := range stores() { t.Run(fmt.Sprintf("AddObservationToNonExistentEntity_%s", name), func(t *testing.T) { - s := factory(t) + s := newStore(t) kb := knowledgeBase{s: s} // Setup valid entity for comparison @@ -385,11 +377,9 @@ func TestErrorHandling(t *testing.T) { // TestFileFormatting verifies the JSON storage format structure. func TestFileFormatting(t *testing.T) { - factories := getStoreFactories() - - for name, factory := range factories { + for name, newStore := range stores() { t.Run(name, func(t *testing.T) { - s := factory(t) + s := newStore(t) kb := knowledgeBase{s: s} // Setup test entity @@ -438,11 +428,9 @@ func TestFileFormatting(t *testing.T) { // TestMCPServerIntegration tests the knowledge base through MCP server layer. func TestMCPServerIntegration(t *testing.T) { - factories := getStoreFactories() - - for name, factory := range factories { + for name, newStore := range stores() { t.Run(name, func(t *testing.T) { - s := factory(t) + s := newStore(t) kb := knowledgeBase{s: s} // Create mock server session @@ -639,11 +627,9 @@ func TestMCPServerIntegration(t *testing.T) { // TestMCPErrorHandling tests error scenarios through MCP layer. func TestMCPErrorHandling(t *testing.T) { - factories := getStoreFactories() - - for name, factory := range factories { + for name, newStore := range stores() { t.Run(name, func(t *testing.T) { - s := factory(t) + s := newStore(t) kb := knowledgeBase{s: s} ctx := context.Background() @@ -661,15 +647,15 @@ func TestMCPErrorHandling(t *testing.T) { }, } - obsResult, err := kb.AddObservations(ctx, serverSession, addObsParams) - if err != nil { - t.Fatalf("MCP AddObservations call failed: %v", err) - } - if !obsResult.IsError { + _, err := kb.AddObservations(ctx, serverSession, addObsParams) + if err == nil { t.Errorf("expected MCP AddObservations to return error for non-existent entity") - } - if len(obsResult.Content) == 0 { - t.Errorf("expected error content in MCP response") + } else { + // Verify the error message contains expected text + expectedErrorMsg := "entity with name NonExistentEntity not found" + if !strings.Contains(err.Error(), expectedErrorMsg) { + t.Errorf("expected error message to contain '%s', got: %v", expectedErrorMsg, err) + } } }) } diff --git a/examples/memory/main.go b/examples/memory/main.go index ba99f96f..d3d78110 100644 --- a/examples/memory/main.go +++ b/examples/memory/main.go @@ -24,34 +24,6 @@ type HiArgs struct { Name string `json:"name"` } -// Entity represents a knowledge graph node with observations. -type Entity struct { - Name string `json:"name"` - EntityType string `json:"entityType"` - Observations []string `json:"observations"` -} - -// Relation represents a directed edge between two entities. -type Relation struct { - From string `json:"from"` - To string `json:"to"` - RelationType string `json:"relationType"` -} - -// Observation contains facts about an entity. -type Observation struct { - EntityName string `json:"entityName"` - Contents []string `json:"contents"` - - Observations []string `json:"observations,omitempty"` // Used for deletion operations -} - -// KnowledgeGraph represents the complete graph structure. -type KnowledgeGraph struct { - Entities []Entity `json:"entities"` - Relations []Relation `json:"relations"` -} - // CreateEntitiesArgs defines the create entities tool parameters. type CreateEntitiesArgs struct { Entities []Entity `json:"entities" mcp:"entities to create"` From bed7d19958548a8a03f7c0e07b149518c12e763e Mon Sep 17 00:00:00 2001 From: MegaGrindStone Date: Wed, 9 Jul 2025 00:00:06 +0700 Subject: [PATCH 7/7] refactor: simplify test variable name in memory example --- examples/memory/kb_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/memory/kb_test.go b/examples/memory/kb_test.go index 57e8506c..e4fbacc9 100644 --- a/examples/memory/kb_test.go +++ b/examples/memory/kb_test.go @@ -652,9 +652,9 @@ func TestMCPErrorHandling(t *testing.T) { t.Errorf("expected MCP AddObservations to return error for non-existent entity") } else { // Verify the error message contains expected text - expectedErrorMsg := "entity with name NonExistentEntity not found" - if !strings.Contains(err.Error(), expectedErrorMsg) { - t.Errorf("expected error message to contain '%s', got: %v", expectedErrorMsg, err) + want := "entity with name NonExistentEntity not found" + if !strings.Contains(err.Error(), want) { + t.Errorf("expected error message to contain '%s', got: %v", want, err) } } })