Skip to content

Commit 784df73

Browse files
committed
review
1 parent 8e93463 commit 784df73

File tree

4 files changed

+150
-15
lines changed

4 files changed

+150
-15
lines changed

src/ManagedCode.MCPGateway/McpGateway.cs

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -422,13 +422,13 @@ public async Task<McpGatewayInvokeResult> InvokeAsync(
422422

423423
if (!string.IsNullOrWhiteSpace(request.Query) &&
424424
!arguments.ContainsKey(QueryArgumentName) &&
425-
entry.Descriptor.RequiredArguments.Contains(QueryArgumentName, StringComparer.OrdinalIgnoreCase))
425+
SupportsArgument(entry.Descriptor, QueryArgumentName))
426426
{
427427
arguments[QueryArgumentName] = request.Query;
428428
}
429429

430-
MapRequestArgument(arguments, entry.Descriptor.RequiredArguments, ContextArgumentName, request.Context);
431-
MapRequestArgument(arguments, entry.Descriptor.RequiredArguments, ContextSummaryArgumentName, request.ContextSummary);
430+
MapRequestArgument(arguments, entry.Descriptor, ContextArgumentName, request.Context);
431+
MapRequestArgument(arguments, entry.Descriptor, ContextSummaryArgumentName, request.ContextSummary);
432432

433433
try
434434
{
@@ -948,13 +948,13 @@ private static void AppendJsonElementTerms(List<string> terms, string key, JsonE
948948

949949
private static void MapRequestArgument(
950950
IDictionary<string, object?> arguments,
951-
IReadOnlyList<string> requiredArguments,
951+
McpGatewayToolDescriptor descriptor,
952952
string argumentName,
953953
object? value)
954954
{
955955
if (value is null ||
956956
arguments.ContainsKey(argumentName) ||
957-
!requiredArguments.Contains(argumentName, StringComparer.OrdinalIgnoreCase))
957+
!SupportsArgument(descriptor, argumentName))
958958
{
959959
return;
960960
}
@@ -967,6 +967,39 @@ private static void MapRequestArgument(
967967
arguments[argumentName] = value;
968968
}
969969

970+
private static bool SupportsArgument(
971+
McpGatewayToolDescriptor descriptor,
972+
string argumentName)
973+
{
974+
if (descriptor.RequiredArguments.Contains(argumentName, StringComparer.OrdinalIgnoreCase))
975+
{
976+
return true;
977+
}
978+
979+
if (string.IsNullOrWhiteSpace(descriptor.InputSchemaJson))
980+
{
981+
return false;
982+
}
983+
984+
try
985+
{
986+
using var schemaDocument = JsonDocument.Parse(descriptor.InputSchemaJson);
987+
if (!schemaDocument.RootElement.TryGetProperty("properties", out var properties) ||
988+
properties.ValueKind != JsonValueKind.Object)
989+
{
990+
return false;
991+
}
992+
993+
return properties
994+
.EnumerateObject()
995+
.Any(property => string.Equals(property.Name, argumentName, StringComparison.OrdinalIgnoreCase));
996+
}
997+
catch (JsonException)
998+
{
999+
return false;
1000+
}
1001+
}
1002+
9701003
private static McpClientTool AttachInvocationMeta(McpClientTool tool, McpGatewayInvokeRequest request)
9711004
{
9721005
var meta = BuildInvocationMeta(request);
@@ -1039,12 +1072,7 @@ private static double CalculateLexicalScore(
10391072
return 0d;
10401073
}
10411074

1042-
var corpus = BuildSearchTerms(string.Join(
1043-
" ",
1044-
entry.Descriptor.ToolName,
1045-
entry.Descriptor.DisplayName,
1046-
entry.Descriptor.Description,
1047-
entry.Descriptor.SourceId));
1075+
var corpus = BuildSearchTerms(entry.Document);
10481076

10491077
var score = 0d;
10501078
foreach (var term in searchTerms)

src/ManagedCode.MCPGateway/Models/McpGatewayOptions.cs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,7 @@ private async Task<McpClient> GetClientAsync(
323323

324324
if (_clientTask is not null)
325325
{
326-
_client = await _clientTask.WaitAsync(cancellationToken);
327-
return _client;
326+
return await AwaitClientTaskAsync(_clientTask, cancellationToken);
328327
}
329328

330329
await _sync.WaitAsync(cancellationToken);
@@ -342,8 +341,35 @@ private async Task<McpClient> GetClientAsync(
342341
_sync.Release();
343342
}
344343

345-
_client = await _clientTask.WaitAsync(cancellationToken);
346-
return _client;
344+
return await AwaitClientTaskAsync(_clientTask, cancellationToken);
345+
}
346+
347+
private async Task<McpClient> AwaitClientTaskAsync(
348+
Task<McpClient> clientTask,
349+
CancellationToken cancellationToken)
350+
{
351+
try
352+
{
353+
_client = await clientTask.WaitAsync(cancellationToken);
354+
return _client;
355+
}
356+
catch when (clientTask.IsFaulted || clientTask.IsCanceled)
357+
{
358+
await _sync.WaitAsync(CancellationToken.None);
359+
try
360+
{
361+
if (ReferenceEquals(_clientTask, clientTask))
362+
{
363+
_clientTask = null;
364+
}
365+
}
366+
finally
367+
{
368+
_sync.Release();
369+
}
370+
371+
throw;
372+
}
347373
}
348374
}
349375

tests/ManagedCode.MCPGateway.Tests/Invocation/McpGatewayInvocationTests.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,28 @@ public async Task InvokeAsync_InvokesLocalFunctionAndMapsQueryArgument()
3232
await Assert.That((string)invokeResult.Output!).IsEqualTo("HELLO GATEWAY");
3333
}
3434

35+
[TUnit.Core.Test]
36+
public async Task InvokeAsync_MapsQueryArgumentWhenSchemaMarksItOptional()
37+
{
38+
await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options =>
39+
{
40+
options.AddTool(
41+
"local",
42+
CreateFunction(OptionalQueryEcho, "optional_query_echo", "Echo optional query text in uppercase."));
43+
});
44+
45+
var gateway = serviceProvider.GetRequiredService<IMcpGateway>();
46+
await gateway.BuildIndexAsync();
47+
48+
var invokeResult = await gateway.InvokeAsync(new McpGatewayInvokeRequest(
49+
ToolId: "local:optional_query_echo",
50+
Query: "hello gateway"));
51+
52+
await Assert.That(invokeResult.IsSuccess).IsTrue();
53+
await Assert.That(invokeResult.Output).IsTypeOf<string>();
54+
await Assert.That((string)invokeResult.Output!).IsEqualTo("HELLO GATEWAY");
55+
}
56+
3557
[TUnit.Core.Test]
3658
public async Task InvokeAsync_MapsContextSummaryToRequiredLocalArguments()
3759
{
@@ -317,6 +339,9 @@ private static AIFunction CreateFunction(Delegate callback, string name, string
317339

318340
private static string TextUppercase([Description("Text to uppercase.")] string query) => query.ToUpperInvariant();
319341

342+
private static string OptionalQueryEcho([Description("Text to uppercase.")] string? query = null)
343+
=> (query ?? "missing").ToUpperInvariant();
344+
320345
private static string EchoContextSummary(
321346
[Description("Main query text.")] string query,
322347
[Description("Execution context summary.")] string contextSummary)

tests/ManagedCode.MCPGateway.Tests/Search/McpGatewaySearchTests.cs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using Microsoft.Extensions.AI;
66
using Microsoft.Extensions.DependencyInjection;
77

8+
using ModelContextProtocol.Client;
9+
810
namespace ManagedCode.MCPGateway.Tests;
911

1012
public sealed class McpGatewaySearchTests
@@ -112,6 +114,23 @@ public async Task SearchAsync_UsesContextDictionaryForLexicalFallback()
112114
await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:weather_search_forecast");
113115
}
114116

117+
[TUnit.Core.Test]
118+
public async Task SearchAsync_UsesSchemaTermsForLexicalFallback()
119+
{
120+
await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options =>
121+
{
122+
options.AddTool("local", CreateFunction(SearchGitHub, "github_search_issues", "Search GitHub issues and pull requests by user query."));
123+
options.AddTool("local", CreateFunction(FilterAdvisories, "advisory_lookup", "Lookup advisory records."));
124+
});
125+
var gateway = serviceProvider.GetRequiredService<IMcpGateway>();
126+
127+
await gateway.BuildIndexAsync();
128+
var searchResult = await gateway.SearchAsync("severity filter", maxResults: 1);
129+
130+
await Assert.That(searchResult.RankingMode).IsEqualTo("lexical");
131+
await Assert.That(searchResult.Matches[0].ToolId).IsEqualTo("local:advisory_lookup");
132+
}
133+
115134
[TUnit.Core.Test]
116135
public async Task SearchAsync_UsesBrowseModeWhenQueryAndContextAreMissing()
117136
{
@@ -248,6 +267,40 @@ public async Task BuildIndexAsync_RebuildsAfterNewToolIsRegistered()
248267
await Assert.That(secondBuild.ToolCount).IsEqualTo(2);
249268
}
250269

270+
[TUnit.Core.Test]
271+
public async Task BuildIndexAsync_RetriesFailedMcpClientFactoryOnNextBuild()
272+
{
273+
await using var serverHost = await TestMcpServerHost.StartAsync();
274+
275+
var attempts = 0;
276+
await using var serviceProvider = GatewayTestServiceProviderFactory.Create(options =>
277+
{
278+
options.AddMcpClientFactory(
279+
"test-mcp",
280+
async _ =>
281+
{
282+
attempts++;
283+
if (attempts == 1)
284+
{
285+
throw new InvalidOperationException("temporary startup failure");
286+
}
287+
288+
return serverHost.Client;
289+
},
290+
disposeClient: false);
291+
});
292+
var gateway = serviceProvider.GetRequiredService<IMcpGateway>();
293+
294+
var firstBuild = await gateway.BuildIndexAsync();
295+
var secondBuild = await gateway.BuildIndexAsync();
296+
297+
await Assert.That(attempts).IsEqualTo(2);
298+
await Assert.That(firstBuild.ToolCount).IsEqualTo(0);
299+
await Assert.That(firstBuild.Diagnostics.Any(static diagnostic => diagnostic.Code == "source_load_failed")).IsTrue();
300+
await Assert.That(secondBuild.ToolCount).IsEqualTo(3);
301+
await Assert.That(secondBuild.Diagnostics.Any(static diagnostic => diagnostic.Code == "source_load_failed")).IsFalse();
302+
}
303+
251304
[TUnit.Core.Test]
252305
public async Task ListToolsAsync_BuildsIndexOnDemand()
253306
{
@@ -281,4 +334,7 @@ private static AIFunction CreateFunction(Delegate callback, string name, string
281334
private static string SearchGitHubAgain([Description("Search query text.")] string query) => $"github-duplicate:{query}";
282335

283336
private static string SearchWeather([Description("City or weather request text.")] string query) => $"weather:{query}";
337+
338+
private static string FilterAdvisories([Description("Severity filter to apply to advisory lookups.")] string severity)
339+
=> $"advisory:{severity}";
284340
}

0 commit comments

Comments
 (0)