Skip to content

Commit 76b907a

Browse files
committed
fb
1 parent 5703fc0 commit 76b907a

File tree

5 files changed

+110
-24
lines changed

5 files changed

+110
-24
lines changed

src/Utility/ChatTools.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,10 @@ internal string CallLocal(ChatToolCall call)
228228
var arguments = new List<object>();
229229
if (call.FunctionArguments != null)
230230
{
231-
ToolsUtility.ParseFunctionCallArgs(call.FunctionArguments, out arguments);
231+
if (!_methods.TryGetValue(call.FunctionName, out MethodInfo method))
232+
return $"I don't have a tool called {call.FunctionName}";
233+
234+
ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments);
232235
}
233236
return CallLocal(call.FunctionName, [.. arguments]);
234237
}

src/Utility/ResponseTools.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,10 @@ internal string CallLocal(FunctionCallResponseItem call)
230230
List<object> arguments = new();
231231
if (call.FunctionArguments != null)
232232
{
233-
ToolsUtility.ParseFunctionCallArgs(call.FunctionArguments, out arguments);
233+
if (!_methods.TryGetValue(call.FunctionName, out MethodInfo method))
234+
return $"I don't have a tool called {call.FunctionName}";
235+
236+
ToolsUtility.ParseFunctionCallArgs(method, call.FunctionArguments, out arguments);
234237
}
235238
return CallLocal(call.FunctionName, [.. arguments]);
236239
}

src/Utility/ToolsUtility.cs

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ internal static ReadOnlySpan<byte> ClrToJsonTypeUtf8(Type clrType) =>
3939
{
4040
Type t when t == typeof(double) => "number"u8,
4141
Type t when t == typeof(int) => "number"u8,
42+
Type t when t == typeof(long) => "number"u8,
43+
Type t when t == typeof(float) => "number"u8,
4244
Type t when t == typeof(string) => "string"u8,
4345
Type t when t == typeof(bool) => "bool"u8,
4446
_ => throw new NotImplementedException()
@@ -146,19 +148,36 @@ internal static float CosineSimilarity(ReadOnlySpan<float> x, ReadOnlySpan<float
146148
return result;
147149
}
148150

149-
internal static void ParseFunctionCallArgs(BinaryData functionCallArguments, out List<object> arguments)
151+
internal static void ParseFunctionCallArgs(MethodInfo method, BinaryData functionCallArguments, out List<object> arguments)
150152
{
151153
arguments = new List<object>();
152154
using var document = JsonDocument.Parse(functionCallArguments);
153-
foreach (JsonProperty argument in document.RootElement.EnumerateObject())
155+
var parameters = method.GetParameters();
156+
var argumentsByName = document.RootElement.EnumerateObject().ToDictionary(p => p.Name, p => p.Value);
157+
158+
foreach (var param in parameters)
154159
{
155-
arguments.Add(argument.Value.ValueKind switch
160+
if (!argumentsByName.TryGetValue(param.Name!, out var value))
161+
{
162+
if (param.HasDefaultValue)
163+
{
164+
arguments.Add(param.DefaultValue!);
165+
continue;
166+
}
167+
throw new JsonException($"Required parameter '{param.Name}' not found in function call arguments.");
168+
}
169+
170+
arguments.Add(value.ValueKind switch
156171
{
157-
JsonValueKind.String => argument.Value.GetString()!,
158-
JsonValueKind.Number => argument.Value.GetInt32(),
172+
JsonValueKind.String => value.GetString()!,
173+
JsonValueKind.Number when param.ParameterType == typeof(int) => value.GetInt32(),
174+
JsonValueKind.Number when param.ParameterType == typeof(long) => value.GetInt64(),
175+
JsonValueKind.Number when param.ParameterType == typeof(double) => value.GetDouble(),
176+
JsonValueKind.Number when param.ParameterType == typeof(float) => value.GetSingle(),
159177
JsonValueKind.True => true,
160178
JsonValueKind.False => false,
161-
_ => throw new NotImplementedException()
179+
JsonValueKind.Null when param.HasDefaultValue => param.DefaultValue!,
180+
_ => throw new NotImplementedException($"Conversion from {value.ValueKind} to {param.ParameterType.Name} is not implemented.")
162181
});
163182
}
164183
}

tests/Utility/ChatToolsTests.cs

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ private class TestTools
2323
{
2424
public static string Echo(string message) => message;
2525
public static int Add(int a, int b) => a + b;
26+
public static double Multiply(double x, double y) => x * y;
27+
public static bool IsGreaterThan(long value1, long value2) => value1 > value2;
28+
public static float Divide(float numerator, float denominator) => numerator / denominator;
29+
public static string ConcatWithBool(string text, bool flag) => $"{text}:{flag}";
2630
}
2731

2832
private Mock<EmbeddingClient> mockEmbeddingClient;
@@ -39,9 +43,13 @@ public void CanAddLocalTools()
3943
var tools = new ChatTools();
4044
tools.AddFunctionTools(typeof(TestTools));
4145

42-
Assert.That(tools.Tools, Has.Count.EqualTo(2));
46+
Assert.That(tools.Tools, Has.Count.EqualTo(6));
4347
Assert.That(tools.Tools.Any(t => t.FunctionName == "Echo"));
4448
Assert.That(tools.Tools.Any(t => t.FunctionName == "Add"));
49+
Assert.That(tools.Tools.Any(t => t.FunctionName == "Multiply"));
50+
Assert.That(tools.Tools.Any(t => t.FunctionName == "IsGreaterThan"));
51+
Assert.That(tools.Tools.Any(t => t.FunctionName == "Divide"));
52+
Assert.That(tools.Tools.Any(t => t.FunctionName == "ConcatWithBool"));
4553
}
4654

4755
[Test]
@@ -53,17 +61,29 @@ public async Task CanCallToolsAsync()
5361
var toolCalls = new[]
5462
{
5563
ChatToolCall.CreateFunctionToolCall("call1", "Echo", BinaryData.FromString(@"{""message"": ""Hello""}")),
56-
ChatToolCall.CreateFunctionToolCall("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}"))
64+
ChatToolCall.CreateFunctionToolCall("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")),
65+
ChatToolCall.CreateFunctionToolCall("call3", "Multiply", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")),
66+
ChatToolCall.CreateFunctionToolCall("call4", "IsGreaterThan", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")),
67+
ChatToolCall.CreateFunctionToolCall("call5", "Divide", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")),
68+
ChatToolCall.CreateFunctionToolCall("call6", "ConcatWithBool", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}"))
5769
};
5870

5971
var results = await tools.CallAsync(toolCalls);
6072
var resultsList = results.ToList();
6173

62-
Assert.That(resultsList, Has.Count.EqualTo(2));
74+
Assert.That(resultsList, Has.Count.EqualTo(6));
6375
Assert.That(resultsList[0].ToolCallId, Is.EqualTo("call1"));
6476
Assert.That(resultsList[0].Content[0].Text, Is.EqualTo("Hello"));
6577
Assert.That(resultsList[1].ToolCallId, Is.EqualTo("call2"));
6678
Assert.That(resultsList[1].Content[0].Text, Is.EqualTo("5"));
79+
Assert.That(resultsList[2].ToolCallId, Is.EqualTo("call3"));
80+
Assert.That(resultsList[2].Content[0].Text, Is.EqualTo("7.5"));
81+
Assert.That(resultsList[3].ToolCallId, Is.EqualTo("call4"));
82+
Assert.That(resultsList[3].Content[0].Text, Is.EqualTo("True"));
83+
Assert.That(resultsList[4].ToolCallId, Is.EqualTo("call5"));
84+
Assert.That(resultsList[4].Content[0].Text, Is.EqualTo("5"));
85+
Assert.That(resultsList[5].ToolCallId, Is.EqualTo("call6"));
86+
Assert.That(resultsList[5].Content[0].Text, Is.EqualTo("Test:True"));
6787
}
6888

6989
[Test]
@@ -74,9 +94,13 @@ public void CreatesCompletionOptionsWithTools()
7494

7595
var options = tools.ToChatCompletionOptions();
7696

77-
Assert.That(options.Tools, Has.Count.EqualTo(2));
97+
Assert.That(options.Tools, Has.Count.EqualTo(6));
7898
Assert.That(options.Tools.Any(t => t.FunctionName == "Echo"));
7999
Assert.That(options.Tools.Any(t => t.FunctionName == "Add"));
100+
Assert.That(options.Tools.Any(t => t.FunctionName == "Multiply"));
101+
Assert.That(options.Tools.Any(t => t.FunctionName == "IsGreaterThan"));
102+
Assert.That(options.Tools.Any(t => t.FunctionName == "Divide"));
103+
Assert.That(options.Tools.Any(t => t.FunctionName == "ConcatWithBool"));
80104
}
81105

82106
[Test]

tests/Utility/ResponseToolsTests.cs

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ private class TestTools
2121
{
2222
public static string Echo(string message) => message;
2323
public static int Add(int a, int b) => a + b;
24+
public static double Multiply(double x, double y) => x * y;
25+
public static bool IsGreaterThan(long value1, long value2) => value1 > value2;
26+
public static float Divide(float numerator, float denominator) => numerator / denominator;
27+
public static string ConcatWithBool(string text, bool flag) => $"{text}:{flag}";
2428
}
2529

2630
private Mock<EmbeddingClient> mockEmbeddingClient;
@@ -37,9 +41,13 @@ public void CanAddLocalTools()
3741
var tools = new ResponseTools();
3842
tools.AddFunctionTools(typeof(TestTools));
3943

40-
Assert.That(tools.Tools, Has.Count.EqualTo(2));
44+
Assert.That(tools.Tools, Has.Count.EqualTo(6));
4145
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Echo")));
4246
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Add")));
47+
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Multiply")));
48+
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("IsGreaterThan")));
49+
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Divide")));
50+
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("ConcatWithBool")));
4351
}
4452

4553
[Test]
@@ -48,17 +56,42 @@ public async Task CanCallToolAsync()
4856
var tools = new ResponseTools();
4957
tools.AddFunctionTools(typeof(TestTools));
5058

51-
var toolCall = new FunctionCallResponseItem("call1", "Echo", BinaryData.FromString(@"{""message"": ""Hello""}"));
52-
var result = await tools.CallAsync(toolCall);
53-
54-
Assert.That(result.CallId, Is.EqualTo("call1"));
55-
Assert.That(result.FunctionOutput, Is.EqualTo("Hello"));
56-
57-
var addCall = new FunctionCallResponseItem("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}"));
58-
result = await tools.CallAsync(addCall);
59+
var toolCalls = new[]
60+
{
61+
new FunctionCallResponseItem("call1", "Echo", BinaryData.FromString(@"{""message"": ""Hello""}")),
62+
new FunctionCallResponseItem("call2", "Add", BinaryData.FromString(@"{""a"": 2, ""b"": 3}")),
63+
new FunctionCallResponseItem("call3", "Multiply", BinaryData.FromString(@"{""x"": 2.5, ""y"": 3.0}")),
64+
new FunctionCallResponseItem("call4", "IsGreaterThan", BinaryData.FromString(@"{""value1"": 100, ""value2"": 50}")),
65+
new FunctionCallResponseItem("call5", "Divide", BinaryData.FromString(@"{""numerator"": 10.0, ""denominator"": 2.0}")),
66+
new FunctionCallResponseItem("call6", "ConcatWithBool", BinaryData.FromString(@"{""text"": ""Test"", ""flag"": true}"))
67+
};
5968

60-
Assert.That(result.CallId, Is.EqualTo("call2"));
61-
Assert.That(result.FunctionOutput, Is.EqualTo("5"));
69+
foreach (var toolCall in toolCalls)
70+
{
71+
var result = await tools.CallAsync(toolCall);
72+
Assert.That(result.CallId, Is.EqualTo(toolCall.CallId));
73+
switch (toolCall.CallId)
74+
{
75+
case "call1":
76+
Assert.That(result.FunctionOutput, Is.EqualTo("Hello"));
77+
break;
78+
case "call2":
79+
Assert.That(result.FunctionOutput, Is.EqualTo("5"));
80+
break;
81+
case "call3":
82+
Assert.That(result.FunctionOutput, Is.EqualTo("7.5"));
83+
break;
84+
case "call4":
85+
Assert.That(result.FunctionOutput, Is.EqualTo("True"));
86+
break;
87+
case "call5":
88+
Assert.That(result.FunctionOutput, Is.EqualTo("5"));
89+
break;
90+
case "call6":
91+
Assert.That(result.FunctionOutput, Is.EqualTo("Test:True"));
92+
break;
93+
}
94+
}
6295
}
6396

6497
[Test]
@@ -69,9 +102,13 @@ public void CreatesResponseOptionsWithTools()
69102

70103
var options = tools.ToResponseCreationOptions();
71104

72-
Assert.That(options.Tools, Has.Count.EqualTo(2));
105+
Assert.That(options.Tools, Has.Count.EqualTo(6));
73106
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Echo")));
74107
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Add")));
108+
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Multiply")));
109+
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("IsGreaterThan")));
110+
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("Divide")));
111+
Assert.That(tools.Tools.Any(t => ((string)t.GetType().GetProperty("Name").GetValue(t)).Contains("ConcatWithBool")));
75112
}
76113

77114
[Test]

0 commit comments

Comments
 (0)