Skip to content

Commit b691110

Browse files
committed
Update CreateMemories to support linking on creation
1 parent 70dbe20 commit b691110

File tree

6 files changed

+129
-24
lines changed

6 files changed

+129
-24
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using Microsoft.Data.Sqlite;
2+
3+
namespace KnowledgeBaseServer.Extensions;
4+
5+
public static class SqliteExceptionExtensions
6+
{
7+
public static bool IsForeignKeyConstraintViolation(this SqliteException ex) =>
8+
ex is { SqliteErrorCode: 19, SqliteExtendedErrorCode: 787 };
9+
10+
public static bool IsPrimaryKeyConstraintViolation(this SqliteException ex) =>
11+
ex is { SqliteErrorCode: 19, SqliteExtendedErrorCode: 1555 };
12+
}

src/KnowledgeBaseServer/StringExtensions.cs renamed to src/KnowledgeBaseServer/Extensions/StringExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using System.Linq;
22

3-
namespace KnowledgeBaseServer;
3+
namespace KnowledgeBaseServer.Extensions;
44

55
public static class StringExtensions
66
{
Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
using System;
2+
using System.Collections.Generic;
23
using System.ComponentModel;
4+
using System.Data;
35
using System.Linq;
46
using Dapper;
7+
using KnowledgeBaseServer.Extensions;
58
using Microsoft.Data.Sqlite;
69
using ModelContextProtocol.Server;
710

@@ -21,44 +24,53 @@ public static class ConnectMemoriesTool
2124
public static string Handle(
2225
ConnectionString connectionString,
2326
[Description("Id of the parent memory (the older memory).")] Guid parentMemoryId,
24-
[Description("Id of the child memories (the newer memories).")] Guid[] childMemories
27+
[Description("Id of the child memories (the newer memories).")] Guid[] childMemoryIds
2528
)
2629
{
27-
var now = DateTimeOffset.UtcNow;
28-
var data = childMemories
29-
.Select(id => new
30-
{
31-
FromMemoryId = id,
32-
ToMemoryId = parentMemoryId,
33-
Created = now,
34-
})
35-
.ToArray();
36-
3730
using var connection = connectionString.CreateConnection();
3831
using var transaction = connection.BeginTransaction();
3932

4033
try
4134
{
42-
connection.Execute(
43-
sql: """
44-
insert into memory_links (from_memory_id, to_memory_id, created) values
45-
(@FromMemoryId, @ToMemoryId, @Created)
46-
""",
47-
data
48-
);
35+
_ = connection.ConnectMemoriesInternal(transaction, parentMemoryId, childMemoryIds);
4936
transaction.Commit();
5037
}
51-
// FK constraint violation
52-
catch (SqliteException ex) when (ex is { SqliteErrorCode: 19, SqliteExtendedErrorCode: 787 })
38+
catch (SqliteException ex) when (ex.IsForeignKeyConstraintViolation())
5339
{
5440
return "Invalid memory ids provided.";
5541
}
56-
// PK constraint violation
57-
catch (SqliteException ex) when (ex is { SqliteErrorCode: 19, SqliteExtendedErrorCode: 1555 })
42+
catch (SqliteException ex) when (ex.IsPrimaryKeyConstraintViolation())
5843
{
5944
return "Some of the requested memories are already linked.";
6045
}
6146

6247
return "Memories linked successfully.";
6348
}
49+
50+
internal static int ConnectMemoriesInternal(
51+
this IDbConnection connection,
52+
IDbTransaction transaction,
53+
Guid parentMemoryId,
54+
IEnumerable<Guid> childMemoryIds,
55+
DateTimeOffset? now = null
56+
)
57+
{
58+
now ??= DateTimeOffset.UtcNow;
59+
60+
var data = childMemoryIds.Select(id => new
61+
{
62+
FromMemoryId = id,
63+
ToMemoryId = parentMemoryId,
64+
Created = now,
65+
});
66+
67+
return connection.Execute(
68+
sql: """
69+
insert into memory_links (from_memory_id, to_memory_id, created) values
70+
(@FromMemoryId, @ToMemoryId, @Created)
71+
""",
72+
data,
73+
transaction
74+
);
75+
}
6476
}

src/KnowledgeBaseServer/Tools/CreateMemoriesTool.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
using System.Text.Json;
55
using Dapper;
66
using KnowledgeBaseServer.Dtos;
7+
using KnowledgeBaseServer.Extensions;
8+
using Microsoft.Data.Sqlite;
79
using ModelContextProtocol.Server;
810

911
namespace KnowledgeBaseServer.Tools;
@@ -24,7 +26,8 @@ public static string Handle(
2426
JsonSerializerOptions jsonSerializerOptions,
2527
[Description("The topic to use for the memories.")] string topic,
2628
[Description("The text of the memories.")] string[] memories,
27-
[Description("Optional information to provide context for these memories.")] string? context = null
29+
[Description("Optional information to provide context for these memories.")] string? context = null,
30+
[Description("Optionally connect the new memories to an existing parent memory.")] Guid? parentMemoryId = null
2831
)
2932
{
3033
var now = DateTimeOffset.UtcNow;
@@ -109,6 +112,23 @@ insert into memory_search (memory_id, content, context) values
109112
transaction
110113
);
111114

115+
if (parentMemoryId is not null)
116+
{
117+
try
118+
{
119+
_ = connection.ConnectMemoriesInternal(
120+
transaction,
121+
parentMemoryId.Value,
122+
createdMemories.Select(m => m.Id),
123+
now
124+
);
125+
}
126+
catch (SqliteException ex) when (ex.IsForeignKeyConstraintViolation())
127+
{
128+
return $"Invalid {nameof(parentMemoryId)} provided.";
129+
}
130+
}
131+
112132
transaction.Commit();
113133

114134
return JsonSerializer.Serialize(

src/KnowledgeBaseServer/Tools/SearchMemoryTool.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Text.Json;
55
using Dapper;
66
using KnowledgeBaseServer.Dtos;
7+
using KnowledgeBaseServer.Extensions;
78
using ModelContextProtocol.Server;
89

910
namespace KnowledgeBaseServer.Tools;

tests/KnowledgeBaseServer.Tests/Tools/CreateMemoriesToolTests.cs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
using System.Diagnostics;
12
using System.Linq;
23
using System.Text.Json;
34
using Bogus;
45
using Dapper;
56
using KnowledgeBaseServer.Dtos;
7+
using KnowledgeBaseServer.Extensions;
68
using KnowledgeBaseServer.Tests.Data;
79
using KnowledgeBaseServer.Tools;
810
using Shouldly;
@@ -82,6 +84,7 @@ public void ShouldCreateMemoryWithoutContext_WhenContextIsNull()
8284
// assert
8385
using var connection = ConnectionString.CreateConnection();
8486
connection.GetTopics().ShouldHaveSingleItem().Name.ShouldBe(topic);
87+
connection.GetMemoryContexts().ShouldBeEmpty();
8588
connection.GetMemoryLinks().ShouldBeEmpty();
8689
connection.GetMemories().ShouldHaveSingleItem().Content.ShouldBe(expectedMemory);
8790
}
@@ -127,6 +130,63 @@ where content match @Word
127130
);
128131
}
129132

133+
[Fact]
134+
public void ShouldReturnError_WhenParentMemoryIdIsNotValid()
135+
{
136+
// arrange
137+
138+
// act
139+
var result = CreateMemoriesTool.Handle(
140+
ConnectionString,
141+
JsonSerializerOptions.Default,
142+
_faker.Lorem.Sentence(),
143+
[_faker.Lorem.Sentence()],
144+
parentMemoryId: _faker.Random.Guid()
145+
);
146+
147+
// assert
148+
result.ShouldBe("Invalid parentMemoryId provided.");
149+
using var connection = ConnectionString.CreateConnection();
150+
connection.GetTopics().ShouldBeEmpty();
151+
connection.GetMemoryContexts().ShouldBeEmpty();
152+
connection.GetMemories().ShouldBeEmpty();
153+
connection.GetMemoryLinks().ShouldBeEmpty();
154+
}
155+
156+
[Fact]
157+
public void ShouldLinkNewMemories_WhenParentMemoryIdProvided()
158+
{
159+
// arrange
160+
var parentMemories = JsonSerializer.Deserialize<CreatedMemoryDto[]>(
161+
CreateMemoriesTool.Handle(
162+
ConnectionString,
163+
JsonSerializerOptions.Default,
164+
_faker.Lorem.Sentence(),
165+
[_faker.Lorem.Sentence()]
166+
)
167+
);
168+
Debug.Assert(parentMemories is { Length: 1 });
169+
var parentMemoryId = parentMemories[0].Id;
170+
171+
// act
172+
var childMemories = JsonSerializer.Deserialize<CreatedMemoryDto[]>(
173+
CreateMemoriesTool.Handle(
174+
ConnectionString,
175+
JsonSerializerOptions.Default,
176+
_faker.Lorem.Sentence(),
177+
[_faker.Lorem.Sentence()],
178+
parentMemoryId: parentMemoryId
179+
)
180+
);
181+
182+
// assert
183+
var childMemoryId = childMemories.ShouldNotBeNull().ShouldHaveSingleItem().Id;
184+
using var connection = ConnectionString.CreateConnection();
185+
var memoryLink = connection.GetMemoryLinks().ShouldHaveSingleItem();
186+
memoryLink.FromMemoryId.ShouldBe(childMemoryId);
187+
memoryLink.ToMemoryId.ShouldBe(parentMemoryId);
188+
}
189+
130190
[Fact]
131191
public void ShouldReturnNewMemoriesWithIds()
132192
{

0 commit comments

Comments
 (0)