Skip to content

Commit c0a5c7b

Browse files
refactor: address PR feedback in memory example
1 parent 4f69af6 commit c0a5c7b

File tree

3 files changed

+75
-125
lines changed

3 files changed

+75
-125
lines changed

examples/memory/kb.go

Lines changed: 40 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,34 @@ import (
1515
"github.com/modelcontextprotocol/go-sdk/mcp"
1616
)
1717

18+
// Entity represents a knowledge graph node with observations.
19+
type Entity struct {
20+
Name string `json:"name"`
21+
EntityType string `json:"entityType"`
22+
Observations []string `json:"observations"`
23+
}
24+
25+
// Relation represents a directed edge between two entities.
26+
type Relation struct {
27+
From string `json:"from"`
28+
To string `json:"to"`
29+
RelationType string `json:"relationType"`
30+
}
31+
32+
// Observation contains facts about an entity.
33+
type Observation struct {
34+
EntityName string `json:"entityName"`
35+
Contents []string `json:"contents"`
36+
37+
Observations []string `json:"observations,omitempty"` // Used for deletion operations
38+
}
39+
40+
// KnowledgeGraph represents the complete graph structure.
41+
type KnowledgeGraph struct {
42+
Entities []Entity `json:"entities"`
43+
Relations []Relation `json:"relations"`
44+
}
45+
1846
// store provides persistence interface for knowledge base data.
1947
type store interface {
2048
Read() ([]byte, error)
@@ -154,7 +182,7 @@ func (k knowledgeBase) saveGraph(graph KnowledgeGraph) error {
154182
}
155183

156184
// createEntities adds new entities to the graph, skipping duplicates by name.
157-
// Returns the new entities that were actually added.
185+
// It returns the new entities that were actually added.
158186
func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) {
159187
graph, err := k.loadGraph()
160188
if err != nil {
@@ -177,7 +205,7 @@ func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) {
177205
}
178206

179207
// createRelations adds new relations to the graph, skipping exact duplicates.
180-
// Returns the new relations that were actually added.
208+
// It returns the new relations that were actually added.
181209
func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) {
182210
graph, err := k.loadGraph()
183211
if err != nil {
@@ -205,7 +233,7 @@ func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error)
205233
}
206234

207235
// addObservations appends new observations to existing entities.
208-
// Returns the new observations that were actually added.
236+
// It returns the new observations that were actually added.
209237
func (k knowledgeBase) addObservations(observations []Observation) ([]Observation, error) {
210238
graph, err := k.loadGraph()
211239
if err != nil {
@@ -408,11 +436,7 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, ss *mcp.ServerSession
408436

409437
entities, err := k.createEntities(params.Arguments.Entities)
410438
if err != nil {
411-
res.IsError = true
412-
res.Content = []mcp.Content{
413-
&mcp.TextContent{Text: err.Error()},
414-
}
415-
return &res, nil
439+
return nil, err
416440
}
417441

418442
res.Content = []mcp.Content{
@@ -431,11 +455,7 @@ func (k knowledgeBase) CreateRelations(ctx context.Context, ss *mcp.ServerSessio
431455

432456
relations, err := k.createRelations(params.Arguments.Relations)
433457
if err != nil {
434-
res.IsError = true
435-
res.Content = []mcp.Content{
436-
&mcp.TextContent{Text: err.Error()},
437-
}
438-
return &res, nil
458+
return nil, err
439459
}
440460

441461
res.Content = []mcp.Content{
@@ -454,11 +474,7 @@ func (k knowledgeBase) AddObservations(ctx context.Context, ss *mcp.ServerSessio
454474

455475
observations, err := k.addObservations(params.Arguments.Observations)
456476
if err != nil {
457-
res.IsError = true
458-
res.Content = []mcp.Content{
459-
&mcp.TextContent{Text: err.Error()},
460-
}
461-
return &res, nil
477+
return nil, err
462478
}
463479

464480
res.Content = []mcp.Content{
@@ -477,11 +493,7 @@ func (k knowledgeBase) DeleteEntities(ctx context.Context, ss *mcp.ServerSession
477493

478494
err := k.deleteEntities(params.Arguments.EntityNames)
479495
if err != nil {
480-
res.IsError = true
481-
res.Content = []mcp.Content{
482-
&mcp.TextContent{Text: err.Error()},
483-
}
484-
return &res, nil
496+
return nil, err
485497
}
486498

487499
res.Content = []mcp.Content{
@@ -496,11 +508,7 @@ func (k knowledgeBase) DeleteObservations(ctx context.Context, ss *mcp.ServerSes
496508

497509
err := k.deleteObservations(params.Arguments.Deletions)
498510
if err != nil {
499-
res.IsError = true
500-
res.Content = []mcp.Content{
501-
&mcp.TextContent{Text: err.Error()},
502-
}
503-
return &res, nil
511+
return nil, err
504512
}
505513

506514
res.Content = []mcp.Content{
@@ -515,11 +523,7 @@ func (k knowledgeBase) DeleteRelations(ctx context.Context, ss *mcp.ServerSessio
515523

516524
err := k.deleteRelations(params.Arguments.Relations)
517525
if err != nil {
518-
res.IsError = true
519-
res.Content = []mcp.Content{
520-
&mcp.TextContent{Text: err.Error()},
521-
}
522-
return &res, nil
526+
return nil, err
523527
}
524528

525529
res.Content = []mcp.Content{
@@ -534,11 +538,7 @@ func (k knowledgeBase) ReadGraph(ctx context.Context, ss *mcp.ServerSession, par
534538

535539
graph, err := k.loadGraph()
536540
if err != nil {
537-
res.IsError = true
538-
res.Content = []mcp.Content{
539-
&mcp.TextContent{Text: err.Error()},
540-
}
541-
return &res, nil
541+
return nil, err
542542
}
543543

544544
res.Content = []mcp.Content{
@@ -554,11 +554,7 @@ func (k knowledgeBase) SearchNodes(ctx context.Context, ss *mcp.ServerSession, p
554554

555555
graph, err := k.searchNodes(params.Arguments.Query)
556556
if err != nil {
557-
res.IsError = true
558-
res.Content = []mcp.Content{
559-
&mcp.TextContent{Text: err.Error()},
560-
}
561-
return &res, nil
557+
return nil, err
562558
}
563559

564560
res.Content = []mcp.Content{
@@ -574,11 +570,7 @@ func (k knowledgeBase) OpenNodes(ctx context.Context, ss *mcp.ServerSession, par
574570

575571
graph, err := k.openNodes(params.Arguments.Names)
576572
if err != nil {
577-
res.IsError = true
578-
res.Content = []mcp.Content{
579-
&mcp.TextContent{Text: err.Error()},
580-
}
581-
return &res, nil
573+
return nil, err
582574
}
583575

584576
res.Content = []mcp.Content{

examples/memory/kb_test.go

Lines changed: 35 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ import (
1111
"os"
1212
"path/filepath"
1313
"reflect"
14+
"slices"
15+
"strings"
1416
"testing"
1517

1618
"github.com/modelcontextprotocol/go-sdk/mcp"
1719
)
1820

19-
// getStoreFactories provides test factories for both storage implementations.
20-
func getStoreFactories() map[string]func(t *testing.T) store {
21+
// stores provides test factories for both storage implementations.
22+
func stores() map[string]func(t *testing.T) store {
2123
return map[string]func(t *testing.T) store{
2224
"file": func(t *testing.T) store {
2325
tempDir, err := os.MkdirTemp("", "kb-test-file-*")
@@ -35,11 +37,9 @@ func getStoreFactories() map[string]func(t *testing.T) store {
3537

3638
// TestKnowledgeBaseOperations verifies CRUD operations work correctly.
3739
func TestKnowledgeBaseOperations(t *testing.T) {
38-
factories := getStoreFactories()
39-
40-
for name, factory := range factories {
40+
for name, newStore := range stores() {
4141
t.Run(name, func(t *testing.T) {
42-
s := factory(t)
42+
s := newStore(t)
4343
kb := knowledgeBase{s: s}
4444

4545
// Verify empty graph loads correctly
@@ -147,19 +147,16 @@ func TestKnowledgeBaseOperations(t *testing.T) {
147147

148148
// Confirm observation removal
149149
graph, _ = kb.loadGraph()
150-
aliceFound := false
151-
for _, e := range graph.Entities {
152-
if e.Name == "Alice" {
153-
aliceFound = true
154-
for _, obs := range e.Observations {
155-
if obs == "Works as developer" {
156-
t.Errorf("observation 'Works as developer' should have been deleted")
157-
}
158-
}
159-
}
160-
}
161-
if !aliceFound {
150+
aliceIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool {
151+
return e.Name == "Alice"
152+
})
153+
if aliceIndex == -1 {
162154
t.Errorf("entity 'Alice' not found after deleting observation")
155+
} else {
156+
alice := graph.Entities[aliceIndex]
157+
if slices.Contains(alice.Observations, "Works as developer") {
158+
t.Errorf("observation 'Works as developer' should have been deleted")
159+
}
163160
}
164161

165162
// Remove relations
@@ -191,11 +188,9 @@ func TestKnowledgeBaseOperations(t *testing.T) {
191188

192189
// TestSaveAndLoadGraph ensures data persists correctly across save/load cycles.
193190
func TestSaveAndLoadGraph(t *testing.T) {
194-
factories := getStoreFactories()
195-
196-
for name, factory := range factories {
191+
for name, newStore := range stores() {
197192
t.Run(name, func(t *testing.T) {
198-
s := factory(t)
193+
s := newStore(t)
199194
kb := knowledgeBase{s: s}
200195

201196
// Setup test data
@@ -251,11 +246,9 @@ func TestSaveAndLoadGraph(t *testing.T) {
251246

252247
// TestDuplicateEntitiesAndRelations verifies duplicate prevention logic.
253248
func TestDuplicateEntitiesAndRelations(t *testing.T) {
254-
factories := getStoreFactories()
255-
256-
for name, factory := range factories {
249+
for name, newStore := range stores() {
257250
t.Run(name, func(t *testing.T) {
258-
s := factory(t)
251+
s := newStore(t)
259252
kb := knowledgeBase{s: s}
260253

261254
// Setup initial state
@@ -355,10 +348,9 @@ func TestErrorHandling(t *testing.T) {
355348
}
356349
})
357350

358-
factories := getStoreFactories()
359-
for name, factory := range factories {
351+
for name, newStore := range stores() {
360352
t.Run(fmt.Sprintf("AddObservationToNonExistentEntity_%s", name), func(t *testing.T) {
361-
s := factory(t)
353+
s := newStore(t)
362354
kb := knowledgeBase{s: s}
363355

364356
// Setup valid entity for comparison
@@ -385,11 +377,9 @@ func TestErrorHandling(t *testing.T) {
385377

386378
// TestFileFormatting verifies the JSON storage format structure.
387379
func TestFileFormatting(t *testing.T) {
388-
factories := getStoreFactories()
389-
390-
for name, factory := range factories {
380+
for name, newStore := range stores() {
391381
t.Run(name, func(t *testing.T) {
392-
s := factory(t)
382+
s := newStore(t)
393383
kb := knowledgeBase{s: s}
394384

395385
// Setup test entity
@@ -438,11 +428,9 @@ func TestFileFormatting(t *testing.T) {
438428

439429
// TestMCPServerIntegration tests the knowledge base through MCP server layer.
440430
func TestMCPServerIntegration(t *testing.T) {
441-
factories := getStoreFactories()
442-
443-
for name, factory := range factories {
431+
for name, newStore := range stores() {
444432
t.Run(name, func(t *testing.T) {
445-
s := factory(t)
433+
s := newStore(t)
446434
kb := knowledgeBase{s: s}
447435

448436
// Create mock server session
@@ -639,11 +627,9 @@ func TestMCPServerIntegration(t *testing.T) {
639627

640628
// TestMCPErrorHandling tests error scenarios through MCP layer.
641629
func TestMCPErrorHandling(t *testing.T) {
642-
factories := getStoreFactories()
643-
644-
for name, factory := range factories {
630+
for name, newStore := range stores() {
645631
t.Run(name, func(t *testing.T) {
646-
s := factory(t)
632+
s := newStore(t)
647633
kb := knowledgeBase{s: s}
648634

649635
ctx := context.Background()
@@ -661,15 +647,15 @@ func TestMCPErrorHandling(t *testing.T) {
661647
},
662648
}
663649

664-
obsResult, err := kb.AddObservations(ctx, serverSession, addObsParams)
665-
if err != nil {
666-
t.Fatalf("MCP AddObservations call failed: %v", err)
667-
}
668-
if !obsResult.IsError {
650+
_, err := kb.AddObservations(ctx, serverSession, addObsParams)
651+
if err == nil {
669652
t.Errorf("expected MCP AddObservations to return error for non-existent entity")
670-
}
671-
if len(obsResult.Content) == 0 {
672-
t.Errorf("expected error content in MCP response")
653+
} else {
654+
// Verify the error message contains expected text
655+
expectedErrorMsg := "entity with name NonExistentEntity not found"
656+
if !strings.Contains(err.Error(), expectedErrorMsg) {
657+
t.Errorf("expected error message to contain '%s', got: %v", expectedErrorMsg, err)
658+
}
673659
}
674660
})
675661
}

examples/memory/main.go

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,34 +24,6 @@ type HiArgs struct {
2424
Name string `json:"name"`
2525
}
2626

27-
// Entity represents a knowledge graph node with observations.
28-
type Entity struct {
29-
Name string `json:"name"`
30-
EntityType string `json:"entityType"`
31-
Observations []string `json:"observations"`
32-
}
33-
34-
// Relation represents a directed edge between two entities.
35-
type Relation struct {
36-
From string `json:"from"`
37-
To string `json:"to"`
38-
RelationType string `json:"relationType"`
39-
}
40-
41-
// Observation contains facts about an entity.
42-
type Observation struct {
43-
EntityName string `json:"entityName"`
44-
Contents []string `json:"contents"`
45-
46-
Observations []string `json:"observations,omitempty"` // Used for deletion operations
47-
}
48-
49-
// KnowledgeGraph represents the complete graph structure.
50-
type KnowledgeGraph struct {
51-
Entities []Entity `json:"entities"`
52-
Relations []Relation `json:"relations"`
53-
}
54-
5527
// CreateEntitiesArgs defines the create entities tool parameters.
5628
type CreateEntitiesArgs struct {
5729
Entities []Entity `json:"entities" mcp:"entities to create"`

0 commit comments

Comments
 (0)