Skip to content

Commit a95138f

Browse files
committed
more tests
1 parent 3b7622c commit a95138f

File tree

2 files changed

+178
-24
lines changed

2 files changed

+178
-24
lines changed

src/ManagedCode.GraphRag.Postgres/ApacheAge/AgeClient.cs

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -150,35 +150,49 @@ public async Task ExecuteCypherAsync(string graph, string cypher, CancellationTo
150150
ArgumentException.ThrowIfNullOrWhiteSpace(cypher);
151151
CheckForExistingConnection();
152152

153-
await using var command = new NpgsqlCommand(
154-
$"SELECT * FROM ag_catalog.cypher('{graph}', $$ {cypher} $$) as (result ag_catalog.agtype);",
155-
_connection);
153+
var commandText = $"SELECT * FROM ag_catalog.cypher('{graph}', $$ {cypher} $$) as (result ag_catalog.agtype);";
156154

157-
try
158-
{
159-
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
160-
LogMessages.CypherExecuted(_logger, graph, cypher);
161-
}
162-
catch (PostgresException ex)
155+
for (var attempt = 0; attempt < 2; attempt++)
163156
{
164-
LogMessages.CypherExecutionError(
165-
_logger,
166-
$"Graph: {graph}. {ex.MessageText}",
167-
cypher,
168-
ex);
169-
throw new AgeException($"Could not execute Cypher command. Graph: {graph}. Cypher: {cypher}", ex);
170-
}
171-
catch (Exception ex)
172-
{
173-
LogMessages.CypherExecutionError(
174-
_logger,
175-
$"Graph: {graph}. {ex.Message}",
176-
cypher,
177-
ex);
178-
throw new AgeException($"Could not execute Cypher command. Graph: {graph}. Cypher: {cypher}", ex);
157+
await using var command = new NpgsqlCommand(commandText, _connection);
158+
159+
try
160+
{
161+
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
162+
LogMessages.CypherExecuted(_logger, graph, cypher);
163+
return;
164+
}
165+
catch (PostgresException ex) when (ShouldRetryOnLabelCreationRace(ex) && attempt == 0)
166+
{
167+
// AGE sometimes races on label/relation creation under parallel writes. Retry once.
168+
continue;
169+
}
170+
catch (PostgresException ex)
171+
{
172+
LogMessages.CypherExecutionError(
173+
_logger,
174+
$"Graph: {graph}. {ex.MessageText}",
175+
cypher,
176+
ex);
177+
throw new AgeException($"Could not execute Cypher command. Graph: {graph}. Cypher: {cypher}", ex);
178+
}
179+
catch (Exception ex)
180+
{
181+
LogMessages.CypherExecutionError(
182+
_logger,
183+
$"Graph: {graph}. {ex.Message}",
184+
cypher,
185+
ex);
186+
throw new AgeException($"Could not execute Cypher command. Graph: {graph}. Cypher: {cypher}", ex);
187+
}
179188
}
180189
}
181190

191+
private static bool ShouldRetryOnLabelCreationRace(PostgresException exception) =>
192+
exception.SqlState is PostgresErrorCodes.DuplicateTable or
193+
PostgresErrorCodes.DuplicateObject or
194+
PostgresErrorCodes.UniqueViolation;
195+
182196
public async Task<AgeDataReader> ExecuteQueryAsync(string query, CancellationToken cancellationToken = default, params object?[] parameters)
183197
{
184198
ArgumentException.ThrowIfNullOrWhiteSpace(query);

tests/ManagedCode.GraphRag.Tests/Storage/Postgres/PostgresAgtypeParameterTests.cs

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,135 @@ public async Task GraphStore_UpsertNodes_WithStringProperties_UsesAgtypeParamete
110110
Assert.Equal("mentorship", stored.Properties["category"]?.ToString());
111111
}
112112

113+
[Fact]
114+
public async Task GraphStore_RejectsInjectionAttempts_InPropertyValues()
115+
{
116+
var store = _fixture.Services.GetKeyedService<IGraphStore>("postgres");
117+
Assert.NotNull(store);
118+
await store!.InitializeAsync();
119+
120+
var label = GraphStoreTestProviders.GetLabel("postgres");
121+
var sentinelId = $"postgres-sentinel-{Guid.NewGuid():N}";
122+
var attackerId = $"postgres-inject-{Guid.NewGuid():N}";
123+
124+
// Baseline node that must survive any attempted injection.
125+
await store.UpsertNodeAsync(sentinelId, label, new Dictionary<string, object?> { ["name"] = "sentinel" });
126+
127+
var injectionPayload = "alice'); MATCH (n) DETACH DELETE n; //";
128+
129+
await store.UpsertNodeAsync(attackerId, label, new Dictionary<string, object?>
130+
{
131+
["name"] = injectionPayload,
132+
["role"] = "attacker"
133+
});
134+
135+
var nodes = await CollectAsync(store.GetNodesAsync());
136+
Assert.Contains(nodes, n => n.Id == sentinelId && n.Properties["name"]?.ToString() == "sentinel");
137+
var injected = Assert.Single(nodes, n => n.Id == attackerId);
138+
Assert.Equal(injectionPayload, injected.Properties["name"]?.ToString());
139+
Assert.Equal("attacker", injected.Properties["role"]?.ToString());
140+
}
141+
142+
[Fact]
143+
public async Task GraphStore_RejectsInjectionAttempts_InIds()
144+
{
145+
var store = _fixture.Services.GetKeyedService<IGraphStore>("postgres");
146+
Assert.NotNull(store);
147+
await store!.InitializeAsync();
148+
149+
var label = GraphStoreTestProviders.GetLabel("postgres");
150+
var safeId = $"postgres-safe-{Guid.NewGuid():N}";
151+
var dangerousId = $"danger-') RETURN 1 //";
152+
153+
await store.UpsertNodeAsync(safeId, label, new Dictionary<string, object?> { ["flag"] = "safe" });
154+
await store.UpsertNodeAsync(dangerousId, label, new Dictionary<string, object?> { ["flag"] = "danger" });
155+
156+
var nodes = await CollectAsync(store.GetNodesAsync());
157+
Assert.Contains(nodes, n => n.Id == safeId && n.Properties["flag"]?.ToString() == "safe");
158+
Assert.Contains(nodes, n => n.Id == dangerousId && n.Properties["flag"]?.ToString() == "danger");
159+
}
160+
161+
[Fact]
162+
public async Task DeleteNodes_DoesNotCascade_WhenIdsContainInjectionLikeContent()
163+
{
164+
var store = _fixture.Services.GetKeyedService<IGraphStore>("postgres");
165+
Assert.NotNull(store);
166+
await store!.InitializeAsync();
167+
168+
var label = GraphStoreTestProviders.GetLabel("postgres");
169+
var sentinelId = $"postgres-safe-{Guid.NewGuid():N}";
170+
var attackerId = "kill-all-nodes\") DETACH DELETE n //";
171+
172+
await store.UpsertNodeAsync(sentinelId, label, new Dictionary<string, object?> { ["flag"] = "safe" });
173+
await store.UpsertNodeAsync(attackerId, label, new Dictionary<string, object?> { ["flag"] = "danger" });
174+
175+
await store.DeleteNodesAsync(new[] { attackerId });
176+
177+
var nodes = await CollectAsync(store.GetNodesAsync());
178+
Assert.Contains(nodes, n => n.Id == sentinelId && n.Properties["flag"]?.ToString() == "safe");
179+
Assert.DoesNotContain(nodes, n => n.Id == attackerId);
180+
}
181+
182+
[Fact]
183+
public async Task UpsertNode_ThrowsOnInvalidLabelCharacters()
184+
{
185+
var store = _fixture.Services.GetKeyedService<IGraphStore>("postgres");
186+
Assert.NotNull(store);
187+
188+
var badLabel = "User) DETACH DELETE n";
189+
await Assert.ThrowsAsync<ArgumentException>(async () =>
190+
await store!.UpsertNodeAsync("id", badLabel, new Dictionary<string, object?>()));
191+
}
192+
193+
public static IEnumerable<object[]> InjectionStringPayloads => new[]
194+
{
195+
new object[] { "'; DROP SCHEMA public; --" },
196+
new object[] { "$$; SELECT 1; $$" },
197+
new object[] { "\") MATCH (n) DETACH DELETE n //" },
198+
new object[] { "alice\n); RETURN 1; //" },
199+
new object[] { "\"quoted\" with {{braces}} and ;" },
200+
new object[] { "unicode-rtl-\u202Epayload" }
201+
};
202+
203+
[Theory]
204+
[MemberData(nameof(InjectionStringPayloads))]
205+
public async Task GraphStore_RejectsInjectionAttempts_InProperties_WithVariousPayloads(string payload)
206+
{
207+
var store = _fixture.Services.GetKeyedService<IGraphStore>("postgres");
208+
Assert.NotNull(store);
209+
await store!.InitializeAsync();
210+
211+
var label = GraphStoreTestProviders.GetLabel("postgres");
212+
var sentinelId = $"postgres-sentinel-{Guid.NewGuid():N}";
213+
var attackerId = $"postgres-inject-{Guid.NewGuid():N}";
214+
215+
await store.UpsertNodeAsync(sentinelId, label, new Dictionary<string, object?> { ["name"] = "sentinel" });
216+
await store.UpsertNodeAsync(attackerId, label, new Dictionary<string, object?> { ["payload"] = payload });
217+
218+
var nodes = await CollectAsync(store.GetNodesAsync());
219+
Assert.Contains(nodes, n => n.Id == sentinelId && n.Properties["name"]?.ToString() == "sentinel");
220+
var injected = Assert.Single(nodes, n => n.Id == attackerId);
221+
Assert.Equal(payload, injected.Properties["payload"]?.ToString());
222+
}
223+
224+
[Fact]
225+
public async Task GraphStore_RejectsInjectionAttempts_InRelationshipTypes()
226+
{
227+
var store = _fixture.Services.GetKeyedService<IGraphStore>("postgres");
228+
Assert.NotNull(store);
229+
await store!.InitializeAsync();
230+
231+
var label = GraphStoreTestProviders.GetLabel("postgres");
232+
var src = $"postgres-rel-{Guid.NewGuid():N}";
233+
var dst = $"postgres-rel-{Guid.NewGuid():N}";
234+
await store!.UpsertNodeAsync(src, label, new Dictionary<string, object?>());
235+
await store.UpsertNodeAsync(dst, label, new Dictionary<string, object?>());
236+
237+
var badType = "BADTYPE'); MATCH (n) DETACH DELETE n; //";
238+
await Assert.ThrowsAsync<ArgumentException>(async () =>
239+
await store.UpsertRelationshipAsync(src, dst, badType, new Dictionary<string, object?> { ["score"] = 1 }));
240+
}
241+
113242
private static async Task<GraphNode?> FindNodeAsync(IGraphStore store, string nodeId, CancellationToken cancellationToken = default)
114243
{
115244
await foreach (var node in store.GetNodesAsync(cancellationToken: cancellationToken))
@@ -122,4 +251,15 @@ public async Task GraphStore_UpsertNodes_WithStringProperties_UsesAgtypeParamete
122251

123252
return null;
124253
}
254+
255+
private static async Task<List<GraphNode>> CollectAsync(IAsyncEnumerable<GraphNode> source)
256+
{
257+
var list = new List<GraphNode>();
258+
await foreach (var item in source)
259+
{
260+
list.Add(item);
261+
}
262+
263+
return list;
264+
}
125265
}

0 commit comments

Comments
 (0)