Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 40 additions & 48 deletions examples/memory/kb.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,34 @@ import (
"github.com/modelcontextprotocol/go-sdk/mcp"
)

// Entity represents a knowledge graph node with observations.
type Entity struct {
Name string `json:"name"`
EntityType string `json:"entityType"`
Observations []string `json:"observations"`
}

// Relation represents a directed edge between two entities.
type Relation struct {
From string `json:"from"`
To string `json:"to"`
RelationType string `json:"relationType"`
}

// Observation contains facts about an entity.
type Observation struct {
EntityName string `json:"entityName"`
Contents []string `json:"contents"`

Observations []string `json:"observations,omitempty"` // Used for deletion operations
}

// KnowledgeGraph represents the complete graph structure.
type KnowledgeGraph struct {
Entities []Entity `json:"entities"`
Relations []Relation `json:"relations"`
}

// store provides persistence interface for knowledge base data.
type store interface {
Read() ([]byte, error)
Expand Down Expand Up @@ -154,7 +182,7 @@ func (k knowledgeBase) saveGraph(graph KnowledgeGraph) error {
}

// createEntities adds new entities to the graph, skipping duplicates by name.
// Returns the new entities that were actually added.
// It returns the new entities that were actually added.
func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) {
graph, err := k.loadGraph()
if err != nil {
Expand All @@ -177,7 +205,7 @@ func (k knowledgeBase) createEntities(entities []Entity) ([]Entity, error) {
}

// createRelations adds new relations to the graph, skipping exact duplicates.
// Returns the new relations that were actually added.
// It returns the new relations that were actually added.
func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error) {
graph, err := k.loadGraph()
if err != nil {
Expand Down Expand Up @@ -205,7 +233,7 @@ func (k knowledgeBase) createRelations(relations []Relation) ([]Relation, error)
}

// addObservations appends new observations to existing entities.
// Returns the new observations that were actually added.
// It returns the new observations that were actually added.
func (k knowledgeBase) addObservations(observations []Observation) ([]Observation, error) {
graph, err := k.loadGraph()
if err != nil {
Expand Down Expand Up @@ -408,11 +436,7 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, ss *mcp.ServerSession

entities, err := k.createEntities(params.Arguments.Entities)
if err != nil {
res.IsError = true
res.Content = []mcp.Content{
&mcp.TextContent{Text: err.Error()},
}
return &res, nil
return nil, err
}

res.Content = []mcp.Content{
Expand All @@ -431,11 +455,7 @@ func (k knowledgeBase) CreateRelations(ctx context.Context, ss *mcp.ServerSessio

relations, err := k.createRelations(params.Arguments.Relations)
if err != nil {
res.IsError = true
res.Content = []mcp.Content{
&mcp.TextContent{Text: err.Error()},
}
return &res, nil
return nil, err
}

res.Content = []mcp.Content{
Expand All @@ -454,11 +474,7 @@ func (k knowledgeBase) AddObservations(ctx context.Context, ss *mcp.ServerSessio

observations, err := k.addObservations(params.Arguments.Observations)
if err != nil {
res.IsError = true
res.Content = []mcp.Content{
&mcp.TextContent{Text: err.Error()},
}
return &res, nil
return nil, err
}

res.Content = []mcp.Content{
Expand All @@ -477,11 +493,7 @@ func (k knowledgeBase) DeleteEntities(ctx context.Context, ss *mcp.ServerSession

err := k.deleteEntities(params.Arguments.EntityNames)
if err != nil {
res.IsError = true
res.Content = []mcp.Content{
&mcp.TextContent{Text: err.Error()},
}
return &res, nil
return nil, err
}

res.Content = []mcp.Content{
Expand All @@ -496,11 +508,7 @@ func (k knowledgeBase) DeleteObservations(ctx context.Context, ss *mcp.ServerSes

err := k.deleteObservations(params.Arguments.Deletions)
if err != nil {
res.IsError = true
res.Content = []mcp.Content{
&mcp.TextContent{Text: err.Error()},
}
return &res, nil
return nil, err
}

res.Content = []mcp.Content{
Expand All @@ -515,11 +523,7 @@ func (k knowledgeBase) DeleteRelations(ctx context.Context, ss *mcp.ServerSessio

err := k.deleteRelations(params.Arguments.Relations)
if err != nil {
res.IsError = true
res.Content = []mcp.Content{
&mcp.TextContent{Text: err.Error()},
}
return &res, nil
return nil, err
}

res.Content = []mcp.Content{
Expand All @@ -534,11 +538,7 @@ func (k knowledgeBase) ReadGraph(ctx context.Context, ss *mcp.ServerSession, par

graph, err := k.loadGraph()
if err != nil {
res.IsError = true
res.Content = []mcp.Content{
&mcp.TextContent{Text: err.Error()},
}
return &res, nil
return nil, err
}

res.Content = []mcp.Content{
Expand All @@ -554,11 +554,7 @@ func (k knowledgeBase) SearchNodes(ctx context.Context, ss *mcp.ServerSession, p

graph, err := k.searchNodes(params.Arguments.Query)
if err != nil {
res.IsError = true
res.Content = []mcp.Content{
&mcp.TextContent{Text: err.Error()},
}
return &res, nil
return nil, err
}

res.Content = []mcp.Content{
Expand All @@ -574,11 +570,7 @@ func (k knowledgeBase) OpenNodes(ctx context.Context, ss *mcp.ServerSession, par

graph, err := k.openNodes(params.Arguments.Names)
if err != nil {
res.IsError = true
res.Content = []mcp.Content{
&mcp.TextContent{Text: err.Error()},
}
return &res, nil
return nil, err
}

res.Content = []mcp.Content{
Expand Down
84 changes: 35 additions & 49 deletions examples/memory/kb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ import (
"os"
"path/filepath"
"reflect"
"slices"
"strings"
"testing"

"github.com/modelcontextprotocol/go-sdk/mcp"
)

// getStoreFactories provides test factories for both storage implementations.
func getStoreFactories() map[string]func(t *testing.T) store {
// stores provides test factories for both storage implementations.
func stores() map[string]func(t *testing.T) store {
return map[string]func(t *testing.T) store{
"file": func(t *testing.T) store {
tempDir, err := os.MkdirTemp("", "kb-test-file-*")
Expand All @@ -35,11 +37,9 @@ func getStoreFactories() map[string]func(t *testing.T) store {

// TestKnowledgeBaseOperations verifies CRUD operations work correctly.
func TestKnowledgeBaseOperations(t *testing.T) {
factories := getStoreFactories()

for name, factory := range factories {
for name, newStore := range stores() {
t.Run(name, func(t *testing.T) {
s := factory(t)
s := newStore(t)
kb := knowledgeBase{s: s}

// Verify empty graph loads correctly
Expand Down Expand Up @@ -147,19 +147,16 @@ func TestKnowledgeBaseOperations(t *testing.T) {

// Confirm observation removal
graph, _ = kb.loadGraph()
aliceFound := false
for _, e := range graph.Entities {
if e.Name == "Alice" {
aliceFound = true
for _, obs := range e.Observations {
if obs == "Works as developer" {
t.Errorf("observation 'Works as developer' should have been deleted")
}
}
}
}
if !aliceFound {
aliceIndex := slices.IndexFunc(graph.Entities, func(e Entity) bool {
return e.Name == "Alice"
})
if aliceIndex == -1 {
t.Errorf("entity 'Alice' not found after deleting observation")
} else {
alice := graph.Entities[aliceIndex]
if slices.Contains(alice.Observations, "Works as developer") {
t.Errorf("observation 'Works as developer' should have been deleted")
}
}

// Remove relations
Expand Down Expand Up @@ -191,11 +188,9 @@ func TestKnowledgeBaseOperations(t *testing.T) {

// TestSaveAndLoadGraph ensures data persists correctly across save/load cycles.
func TestSaveAndLoadGraph(t *testing.T) {
factories := getStoreFactories()

for name, factory := range factories {
for name, newStore := range stores() {
t.Run(name, func(t *testing.T) {
s := factory(t)
s := newStore(t)
kb := knowledgeBase{s: s}

// Setup test data
Expand Down Expand Up @@ -251,11 +246,9 @@ func TestSaveAndLoadGraph(t *testing.T) {

// TestDuplicateEntitiesAndRelations verifies duplicate prevention logic.
func TestDuplicateEntitiesAndRelations(t *testing.T) {
factories := getStoreFactories()

for name, factory := range factories {
for name, newStore := range stores() {
t.Run(name, func(t *testing.T) {
s := factory(t)
s := newStore(t)
kb := knowledgeBase{s: s}

// Setup initial state
Expand Down Expand Up @@ -355,10 +348,9 @@ func TestErrorHandling(t *testing.T) {
}
})

factories := getStoreFactories()
for name, factory := range factories {
for name, newStore := range stores() {
t.Run(fmt.Sprintf("AddObservationToNonExistentEntity_%s", name), func(t *testing.T) {
s := factory(t)
s := newStore(t)
kb := knowledgeBase{s: s}

// Setup valid entity for comparison
Expand All @@ -385,11 +377,9 @@ func TestErrorHandling(t *testing.T) {

// TestFileFormatting verifies the JSON storage format structure.
func TestFileFormatting(t *testing.T) {
factories := getStoreFactories()

for name, factory := range factories {
for name, newStore := range stores() {
t.Run(name, func(t *testing.T) {
s := factory(t)
s := newStore(t)
kb := knowledgeBase{s: s}

// Setup test entity
Expand Down Expand Up @@ -438,11 +428,9 @@ func TestFileFormatting(t *testing.T) {

// TestMCPServerIntegration tests the knowledge base through MCP server layer.
func TestMCPServerIntegration(t *testing.T) {
factories := getStoreFactories()

for name, factory := range factories {
for name, newStore := range stores() {
t.Run(name, func(t *testing.T) {
s := factory(t)
s := newStore(t)
kb := knowledgeBase{s: s}

// Create mock server session
Expand Down Expand Up @@ -639,11 +627,9 @@ func TestMCPServerIntegration(t *testing.T) {

// TestMCPErrorHandling tests error scenarios through MCP layer.
func TestMCPErrorHandling(t *testing.T) {
factories := getStoreFactories()

for name, factory := range factories {
for name, newStore := range stores() {
t.Run(name, func(t *testing.T) {
s := factory(t)
s := newStore(t)
kb := knowledgeBase{s: s}

ctx := context.Background()
Expand All @@ -661,15 +647,15 @@ func TestMCPErrorHandling(t *testing.T) {
},
}

obsResult, err := kb.AddObservations(ctx, serverSession, addObsParams)
if err != nil {
t.Fatalf("MCP AddObservations call failed: %v", err)
}
if !obsResult.IsError {
_, err := kb.AddObservations(ctx, serverSession, addObsParams)
if err == nil {
t.Errorf("expected MCP AddObservations to return error for non-existent entity")
}
if len(obsResult.Content) == 0 {
t.Errorf("expected error content in MCP response")
} else {
// Verify the error message contains expected text
expectedErrorMsg := "entity with name NonExistentEntity not found"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just want is fine

if !strings.Contains(err.Error(), expectedErrorMsg) {
t.Errorf("expected error message to contain '%s', got: %v", expectedErrorMsg, err)
}
}
})
}
Expand Down
28 changes: 0 additions & 28 deletions examples/memory/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,34 +24,6 @@ type HiArgs struct {
Name string `json:"name"`
}

// Entity represents a knowledge graph node with observations.
type Entity struct {
Name string `json:"name"`
EntityType string `json:"entityType"`
Observations []string `json:"observations"`
}

// Relation represents a directed edge between two entities.
type Relation struct {
From string `json:"from"`
To string `json:"to"`
RelationType string `json:"relationType"`
}

// Observation contains facts about an entity.
type Observation struct {
EntityName string `json:"entityName"`
Contents []string `json:"contents"`

Observations []string `json:"observations,omitempty"` // Used for deletion operations
}

// KnowledgeGraph represents the complete graph structure.
type KnowledgeGraph struct {
Entities []Entity `json:"entities"`
Relations []Relation `json:"relations"`
}

// CreateEntitiesArgs defines the create entities tool parameters.
type CreateEntitiesArgs struct {
Entities []Entity `json:"entities" mcp:"entities to create"`
Expand Down
Loading