Skip to content

Commit dc5d14d

Browse files
committed
tests
1 parent acb8207 commit dc5d14d

File tree

4 files changed

+278
-3
lines changed

4 files changed

+278
-3
lines changed

Directory.Build.props

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
<RepositoryUrl>https://github.com/managedcode/graphrag</RepositoryUrl>
2626
<PackageProjectUrl>https://github.com/managedcode/graphrag</PackageProjectUrl>
2727
<Product>Managed Code GraphRag</Product>
28-
<Version>0.0.2</Version>
29-
<PackageVersion>0.0.2</PackageVersion>
28+
<Version>0.0.3</Version>
29+
<PackageVersion>0.0.3</PackageVersion>
3030

3131
</PropertyGroup>
3232
<PropertyGroup Condition="'$(GITHUB_ACTIONS)' == 'true'">
@@ -42,7 +42,7 @@
4242
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
4343
</PackageReference>
4444
</ItemGroup>
45-
45+
4646
<ItemGroup>
4747
<PackageReference Update="Microsoft.SourceLink.GitHub" Version="8.0.0" />
4848
</ItemGroup>

tests/ManagedCode.GraphRag.Tests/Cache/MemoryPipelineCacheTests.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using System.Collections.Concurrent;
2+
using System.Reflection;
13
using GraphRag.Cache;
24
using Microsoft.Extensions.Caching.Memory;
35

@@ -58,4 +60,45 @@ public async Task ClearAsync_RemovesChildEntries()
5860
Assert.False(await parent.HasAsync("parentValue"));
5961
Assert.False(await child.HasAsync("childValue"));
6062
}
63+
64+
[Fact]
65+
public async Task DeleteAsync_RemovesTrackedKeyEvenWithDebugData()
66+
{
67+
var memoryCache = new MemoryCache(new MemoryCacheOptions());
68+
var cache = new MemoryPipelineCache(memoryCache);
69+
70+
await cache.SetAsync("debug", 123, new Dictionary<string, object?> { ["token"] = "value" });
71+
var keys = GetTrackedKeys(cache);
72+
Assert.Contains(keys.Keys, key => key.EndsWith(":debug", StringComparison.Ordinal));
73+
74+
await cache.DeleteAsync("debug");
75+
76+
Assert.DoesNotContain(GetTrackedKeys(cache).Keys, key => key.EndsWith(":debug", StringComparison.Ordinal));
77+
Assert.False(await cache.HasAsync("debug"));
78+
}
79+
80+
[Fact]
81+
public async Task CreateChild_AfterParentWrites_StillClearsChildEntries()
82+
{
83+
var memoryCache = new MemoryCache(new MemoryCacheOptions());
84+
var parent = new MemoryPipelineCache(memoryCache);
85+
86+
await parent.SetAsync("root", "root");
87+
var child = parent.CreateChild("later-child");
88+
await child.SetAsync("inner", "child");
89+
90+
await parent.ClearAsync();
91+
92+
Assert.False(await parent.HasAsync("root"));
93+
Assert.False(await child.HasAsync("inner"));
94+
}
95+
96+
private static ConcurrentDictionary<string, byte> GetTrackedKeys(MemoryPipelineCache cache)
97+
{
98+
var field = typeof(MemoryPipelineCache)
99+
.GetField("_keys", BindingFlags.NonPublic | BindingFlags.Instance)
100+
?? throw new InvalidOperationException("Could not access keys field.");
101+
102+
return (ConcurrentDictionary<string, byte>)field.GetValue(cache)!;
103+
}
61104
}

tests/ManagedCode.GraphRag.Tests/Integration/CommunitySummariesIntegrationTests.cs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,105 @@ await outputStorage.WriteTableAsync(PipelineTableNames.Relationships, new[]
120120
Assert.True(File.Exists(Path.Combine(outputDir, $"{PipelineTableNames.CommunityReports}.json")));
121121
}
122122

123+
[Fact]
124+
public async Task CommunitySummariesWorkflow_PrefersManualOverAutoPrompts()
125+
{
126+
var outputDir = Path.Combine(_rootDir, "output-auto");
127+
var inputDir = Path.Combine(_rootDir, "input-auto");
128+
var previousDir = Path.Combine(_rootDir, "previous-auto");
129+
Directory.CreateDirectory(outputDir);
130+
Directory.CreateDirectory(inputDir);
131+
Directory.CreateDirectory(previousDir);
132+
133+
var manualDirectory = Path.Combine(_rootDir, "prompt_manual");
134+
var autoDirectory = Path.Combine(_rootDir, "prompt_auto");
135+
136+
var manualSystem = Path.Combine(manualDirectory, "index", "community_reports", "system.txt");
137+
Directory.CreateDirectory(Path.GetDirectoryName(manualSystem)!);
138+
File.WriteAllText(manualSystem, "Manual system override");
139+
140+
var autoSystem = Path.Combine(autoDirectory, "index", "community_reports", "system.txt");
141+
var autoUser = Path.Combine(autoDirectory, "index", "community_reports", "user.txt");
142+
Directory.CreateDirectory(Path.GetDirectoryName(autoSystem)!);
143+
File.WriteAllText(autoSystem, "Auto system value");
144+
File.WriteAllText(autoUser, "Auto template for {{entities}} within {{max_length}} characters.");
145+
146+
var outputStorage = new FilePipelineStorage(outputDir);
147+
await outputStorage.WriteTableAsync(PipelineTableNames.Entities, new[]
148+
{
149+
new EntityRecord("entity-1", 0, "Alice", "Person", "Investigator", new[] { "unit-1" }.ToImmutableArray(), 2, 1, 0, 0),
150+
new EntityRecord("entity-2", 1, "Eve", "Person", "Analyst", new[] { "unit-2" }.ToImmutableArray(), 1, 1, 0, 0)
151+
});
152+
153+
await outputStorage.WriteTableAsync(PipelineTableNames.Relationships, new[]
154+
{
155+
new RelationshipRecord("rel-1", 0, "Alice", "Eve", "collaborates_with", "Joint research", 0.7, 2, new[] { "unit-1" }.ToImmutableArray(), true)
156+
});
157+
158+
var capturedSystem = string.Empty;
159+
var capturedUser = string.Empty;
160+
var services = new ServiceCollection()
161+
.AddSingleton<IChatClient>(new TestChatClientFactory(messages =>
162+
{
163+
var system = messages.First(m => m.Role == ChatRole.System);
164+
var user = messages.First(m => m.Role == ChatRole.User);
165+
capturedSystem = system.Text ?? string.Empty;
166+
capturedUser = user.Text ?? string.Empty;
167+
return new ChatResponse(new ChatMessage(ChatRole.Assistant, "Combined summary"));
168+
}).CreateClient())
169+
.AddGraphRag()
170+
.BuildServiceProvider();
171+
172+
var config = new GraphRagConfig
173+
{
174+
RootDir = _rootDir,
175+
PromptTuning = new PromptTuningConfig
176+
{
177+
Manual = new ManualPromptTuningConfig
178+
{
179+
Enabled = true,
180+
Directory = "prompt_manual"
181+
},
182+
Auto = new AutoPromptTuningConfig
183+
{
184+
Enabled = true,
185+
Directory = "prompt_auto"
186+
}
187+
},
188+
CommunityReports = new CommunityReportsConfig
189+
{
190+
GraphPrompt = null,
191+
TextPrompt = null,
192+
MaxLength = 256
193+
}
194+
};
195+
196+
var context = new PipelineRunContext(
197+
inputStorage: new FilePipelineStorage(inputDir),
198+
outputStorage: outputStorage,
199+
previousStorage: new FilePipelineStorage(previousDir),
200+
cache: new StubPipelineCache(),
201+
callbacks: NoopWorkflowCallbacks.Instance,
202+
stats: new PipelineRunStats(),
203+
state: new PipelineState(),
204+
services: services);
205+
206+
var createCommunities = CreateCommunitiesWorkflow.Create();
207+
await createCommunities(config, context, CancellationToken.None);
208+
209+
var summaries = CommunitySummariesWorkflow.Create();
210+
await summaries(config, context, CancellationToken.None);
211+
212+
Assert.Equal("Manual system override", capturedSystem);
213+
Assert.Contains("Auto template", capturedUser, StringComparison.Ordinal);
214+
Assert.DoesNotContain("{{", capturedUser, StringComparison.Ordinal);
215+
216+
var reports = await outputStorage.LoadTableAsync<CommunityReportRecord>(PipelineTableNames.CommunityReports);
217+
var report = Assert.Single(reports);
218+
Assert.Equal("Combined summary", report.Summary);
219+
Assert.Equal(2, report.EntityTitles.Count);
220+
}
221+
123222
public void Dispose()
124223
{
125224
try

tests/ManagedCode.GraphRag.Tests/Runtime/PipelineExecutorTests.cs

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using GraphRag.Callbacks;
22
using GraphRag.Config;
33
using GraphRag.Indexing.Runtime;
4+
using GraphRag.Logging;
45
using GraphRag.Storage;
56
using ManagedCode.GraphRag.Tests.Infrastructure;
67
using Microsoft.Extensions.DependencyInjection;
@@ -75,4 +76,136 @@ public async Task ExecuteAsync_HonoursStopSignal()
7576
Assert.Single(outputs);
7677
Assert.Equal("first", outputs[0].Workflow);
7778
}
79+
80+
[Fact]
81+
public async Task ExecuteAsync_InvokesCallbacksAndUpdatesStats()
82+
{
83+
var services = new ServiceCollection().BuildServiceProvider();
84+
var callbacks = new RecordingCallbacks();
85+
var stats = new PipelineRunStats();
86+
var context = new PipelineRunContext(
87+
new MemoryPipelineStorage(),
88+
new MemoryPipelineStorage(),
89+
new MemoryPipelineStorage(),
90+
new StubPipelineCache(),
91+
callbacks,
92+
stats,
93+
new PipelineState(),
94+
services);
95+
96+
var pipeline = new WorkflowPipeline("stats", new[]
97+
{
98+
new WorkflowStep("first", async (cfg, ctx, token) =>
99+
{
100+
await Task.Delay(5, token);
101+
return new WorkflowResult("ok");
102+
}),
103+
new WorkflowStep("second", (cfg, ctx, token) => ValueTask.FromResult(new WorkflowResult("done")))
104+
});
105+
106+
var executor = new PipelineExecutor(new NullLogger<PipelineExecutor>());
107+
var results = new List<PipelineRunResult>();
108+
109+
await foreach (var result in executor.ExecuteAsync(pipeline, new GraphRagConfig(), context))
110+
{
111+
results.Add(result);
112+
}
113+
114+
Assert.Equal(new[] { "first", "second" }, callbacks.WorkflowStarts);
115+
Assert.Equal(callbacks.WorkflowStarts, callbacks.WorkflowEnds);
116+
Assert.Equal(2, callbacks.PipelineEndResults?.Count);
117+
Assert.True(callbacks.PipelineStartedWith?.SequenceEqual(pipeline.Names));
118+
119+
Assert.Equal(2, results.Count);
120+
Assert.All(results, r => Assert.Null(r.Errors));
121+
122+
Assert.True(stats.TotalRuntime >= 0);
123+
Assert.True(stats.Workflows.ContainsKey("first"));
124+
Assert.True(stats.Workflows["first"].ContainsKey("overall"));
125+
Assert.True(stats.Workflows.ContainsKey("second"));
126+
Assert.True(stats.Workflows["second"].ContainsKey("overall"));
127+
}
128+
129+
[Fact]
130+
public async Task ExecuteAsync_RecordsExceptionInResultsAndStats()
131+
{
132+
var services = new ServiceCollection().BuildServiceProvider();
133+
var stats = new PipelineRunStats();
134+
var callbacks = new RecordingCallbacks();
135+
var context = new PipelineRunContext(
136+
new MemoryPipelineStorage(),
137+
new MemoryPipelineStorage(),
138+
new MemoryPipelineStorage(),
139+
new StubPipelineCache(),
140+
callbacks,
141+
stats,
142+
new PipelineState(),
143+
services);
144+
145+
var failure = new InvalidOperationException("fail");
146+
var pipeline = new WorkflowPipeline("failing", new[]
147+
{
148+
new WorkflowStep("good", (cfg, ctx, token) => ValueTask.FromResult(new WorkflowResult("done"))),
149+
new WorkflowStep("bad", (cfg, ctx, token) => throw failure),
150+
new WorkflowStep("skipped", (cfg, ctx, token) => ValueTask.FromResult(new WorkflowResult("nope")))
151+
});
152+
153+
var executor = new PipelineExecutor(new NullLogger<PipelineExecutor>());
154+
var results = new List<PipelineRunResult>();
155+
156+
await foreach (var result in executor.ExecuteAsync(pipeline, new GraphRagConfig(), context))
157+
{
158+
results.Add(result);
159+
}
160+
161+
Assert.Equal(2, results.Count);
162+
Assert.Null(results[0].Errors);
163+
var errorResult = results[1];
164+
Assert.NotNull(errorResult.Errors);
165+
var captured = Assert.Single(errorResult.Errors!);
166+
Assert.Same(failure, captured);
167+
168+
Assert.Equal(new[] { "good", "bad" }, callbacks.WorkflowStarts);
169+
Assert.Equal(callbacks.WorkflowStarts, callbacks.WorkflowEnds);
170+
Assert.Equal(2, callbacks.PipelineEndResults?.Count);
171+
172+
Assert.True(stats.Workflows.ContainsKey("good"));
173+
Assert.True(stats.Workflows.ContainsKey("bad"));
174+
Assert.False(stats.Workflows.ContainsKey("skipped"));
175+
Assert.True(stats.TotalRuntime >= 0);
176+
}
177+
178+
private sealed class RecordingCallbacks : IWorkflowCallbacks
179+
{
180+
public IReadOnlyList<string>? PipelineStartedWith { get; private set; }
181+
public List<string> WorkflowStarts { get; } = new();
182+
public List<string> WorkflowEnds { get; } = new();
183+
public IReadOnlyList<PipelineRunResult>? PipelineEndResults { get; private set; }
184+
public List<ProgressSnapshot> ProgressUpdates { get; } = new();
185+
186+
public void PipelineStart(IReadOnlyList<string> names)
187+
{
188+
PipelineStartedWith = names.ToArray();
189+
}
190+
191+
public void PipelineEnd(IReadOnlyList<PipelineRunResult> results)
192+
{
193+
PipelineEndResults = results.ToArray();
194+
}
195+
196+
public void WorkflowStart(string name, object? instance)
197+
{
198+
WorkflowStarts.Add(name);
199+
}
200+
201+
public void WorkflowEnd(string name, object? instance)
202+
{
203+
WorkflowEnds.Add(name);
204+
}
205+
206+
public void ReportProgress(ProgressSnapshot progress)
207+
{
208+
ProgressUpdates.Add(progress);
209+
}
210+
}
78211
}

0 commit comments

Comments
 (0)