Skip to content

Commit 3972597

Browse files
Add schema validation in the InputSchema property setter.
1 parent a5f16ea commit 3972597

File tree

6 files changed

+110
-7
lines changed

6 files changed

+110
-7
lines changed

src/ModelContextProtocol/Protocol/Types/Tool.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Text.Json;
1+
using ModelContextProtocol.Utils.Json;
2+
using System.Text.Json;
23
using System.Text.Json.Serialization;
34

45
namespace ModelContextProtocol.Protocol.Types;
@@ -24,6 +25,23 @@ public class Tool
2425
/// <summary>
2526
/// A JSON Schema object defining the expected parameters for the tool.
2627
/// </summary>
28+
/// <remarks>
29+
/// Needs to a valid JSON schema object that additionally is of type object.
30+
/// </remarks>
2731
[JsonPropertyName("inputSchema")]
28-
public JsonElement InputSchema { get; set; }
32+
public JsonElement InputSchema
33+
{
34+
get => _inputSchema;
35+
set
36+
{
37+
if (!McpJsonUtilities.IsValidMcpToolSchema(value))
38+
{
39+
throw new ArgumentException("The specified document is not a valid MPC tool JSON schema.", nameof(InputSchema));
40+
}
41+
42+
_inputSchema = value;
43+
}
44+
}
45+
46+
private JsonElement _inputSchema = McpJsonUtilities.DefaultMcpToolSchema;
2947
}

src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,31 @@ private static JsonSerializerOptions CreateDefaultOptions()
7777
internal static JsonTypeInfo<T> GetTypeInfo<T>(this JsonSerializerOptions options) =>
7878
(JsonTypeInfo<T>)options.GetTypeInfo(typeof(T));
7979

80+
internal static JsonElement DefaultMcpToolSchema = ParseJsonElement("{\"type\":\"object\"}"u8);
81+
internal static bool IsValidMcpToolSchema(JsonElement element)
82+
{
83+
if (element.ValueKind is not JsonValueKind.Object)
84+
{
85+
return false;
86+
}
87+
88+
foreach (JsonProperty property in element.EnumerateObject())
89+
{
90+
if (property.NameEquals("type"))
91+
{
92+
if (property.Value.ValueKind is not JsonValueKind.String ||
93+
!property.Value.ValueEquals("object"))
94+
{
95+
return false;
96+
}
97+
98+
return true; // No need to check other properties
99+
}
100+
}
101+
102+
return false; // No type keyword found.
103+
}
104+
80105
// Keep in sync with CreateDefaultOptions above.
81106
[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
82107
UseStringEnumConverter = true,
@@ -98,4 +123,11 @@ internal static JsonTypeInfo<T> GetTypeInfo<T>(this JsonSerializerOptions option
98123
[JsonSerializable(typeof(InitializeResult))]
99124
[JsonSerializable(typeof(CallToolResponse))]
100125
internal sealed partial class JsonContext : JsonSerializerContext;
126+
127+
private static JsonElement ParseJsonElement(ReadOnlySpan<byte> utf8Json)
128+
{
129+
Utf8JsonReader reader = new(utf8Json);
130+
return JsonElement.ParseValue(ref reader);
131+
}
132+
101133
}

tests/ModelContextProtocol.TestServer/Program.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ private static ToolsCapability ConfigureTools()
9292
{
9393
Name = "echo",
9494
Description = "Echoes the input back to the client.",
95-
InputSchema = JsonSerializer.SerializeToElement("""
95+
InputSchema = JsonSerializer.Deserialize<JsonElement>("""
9696
{
9797
"type": "object",
9898
"properties": {
@@ -109,7 +109,7 @@ private static ToolsCapability ConfigureTools()
109109
{
110110
Name = "sampleLLM",
111111
Description = "Samples from an LLM using MCP's sampling feature.",
112-
InputSchema = JsonSerializer.SerializeToElement("""
112+
InputSchema = JsonSerializer.Deserialize<JsonElement>("""
113113
{
114114
"type": "object",
115115
"properties": {

tests/ModelContextProtocol.TestSseServer/Program.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
122122
{
123123
Name = "echo",
124124
Description = "Echoes the input back to the client.",
125-
InputSchema = JsonSerializer.SerializeToElement("""
125+
InputSchema = JsonSerializer.Deserialize<JsonElement>("""
126126
{
127127
"type": "object",
128128
"properties": {
@@ -139,7 +139,7 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
139139
{
140140
Name = "sampleLLM",
141141
Description = "Samples from an LLM using MCP's sampling feature.",
142-
InputSchema = JsonSerializer.SerializeToElement("""
142+
InputSchema = JsonSerializer.Deserialize<JsonElement>("""
143143
{
144144
"type": "object",
145145
"properties": {
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using ModelContextProtocol.Protocol.Types;
2+
using System.Text.Json;
3+
4+
namespace ModelContextProtocol.Tests.Protocol;
5+
6+
public static class ProtocolTypeTests
7+
{
8+
[Fact]
9+
public static void ToolInputSchema_HasValidDefaultSchema()
10+
{
11+
var tool = new Tool();
12+
JsonElement jsonElement = tool.InputSchema;
13+
14+
Assert.Equal(JsonValueKind.Object, jsonElement.ValueKind);
15+
Assert.Single(jsonElement.EnumerateObject());
16+
Assert.True(jsonElement.TryGetProperty("type", out JsonElement typeElement));
17+
Assert.Equal(JsonValueKind.String, typeElement.ValueKind);
18+
Assert.Equal("object", typeElement.GetString());
19+
}
20+
21+
[Theory]
22+
[InlineData("null")]
23+
[InlineData("false")]
24+
[InlineData("true")]
25+
[InlineData("3.5e3")]
26+
[InlineData("[]")]
27+
[InlineData("{}")]
28+
[InlineData("""{"properties":{}}""")]
29+
[InlineData("""{"type":"number"}""")]
30+
[InlineData("""{"type":"array"}""")]
31+
[InlineData("""{"type":["object"]}""")]
32+
public static void ToolInputSchema_RejectsInvalidSchemaDocuments(string invalidSchema)
33+
{
34+
using var document = JsonDocument.Parse(invalidSchema);
35+
var tool = new Tool();
36+
37+
Assert.Throws<ArgumentException>(() => tool.InputSchema = document.RootElement);
38+
}
39+
40+
[Theory]
41+
[InlineData("""{"type":"object"}""")]
42+
[InlineData("""{"type":"object", "properties": {}, "required" : [] }""")]
43+
[InlineData("""{"type":"object", "title": "MyAwesomeTool", "description": "It's awesome!", "properties": {}, "required" : ["NotAParam"] }""")]
44+
public static void ToolInputSchema_AcceptsValidSchemaDocuments(string validSchema)
45+
{
46+
using var document = JsonDocument.Parse(validSchema);
47+
var tool = new Tool();
48+
49+
tool.InputSchema = document.RootElement;
50+
Assert.True(JsonElement.DeepEquals(document.RootElement, tool.InputSchema));
51+
}
52+
}

tests/ModelContextProtocol.Tests/Utils/TestServerTransport.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Threading.Channels;
22
using ModelContextProtocol.Protocol.Messages;
33
using ModelContextProtocol.Protocol.Transport;
4+
using ModelContextProtocol.Protocol.Types;
45

56
namespace ModelContextProtocol.Tests.Utils;
67

@@ -71,7 +72,7 @@ private async Task Sampling(JsonRpcRequest request, CancellationToken cancellati
7172
await WriteMessageAsync(new JsonRpcResponse
7273
{
7374
Id = request.Id,
74-
Result = new Protocol.Types.CreateMessageResult { Content = new(), Model = "model", Role = "role" }
75+
Result = new CreateMessageResult { Content = new(), Model = "model", Role = "role" }
7576
}, cancellationToken);
7677
}
7778

0 commit comments

Comments
 (0)