Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ public sealed class FunctionCallContentBuilder
private Dictionary<string, string>? _functionCallIdsByIndex = null;
private Dictionary<string, string>? _functionNamesByIndex = null;
private Dictionary<string, StringBuilder>? _functionArgumentBuildersByIndex = null;
private readonly JsonSerializerOptions? _jsonSerializerOptions;
private readonly JsonSerializerOptions? _jsonSerializerOptions = null;
private readonly bool _retainArgumentTypes = false;

/// <summary>
/// Creates a new instance of the <see cref="FunctionCallContentBuilder"/> class.
Expand All @@ -33,10 +34,12 @@ public FunctionCallContentBuilder()
/// Creates a new instance of the <see cref="FunctionCallContentBuilder"/> class.
/// </summary>
/// <param name="jsonSerializerOptions">The <see cref="JsonSerializerOptions"/> to use for deserializing function arguments.</param>
/// <param name="retainArgumentTypes">A value indicating whether the types of function arguments provided by the AI model are retained by SK or not. By default <see langword="false"/>.</param>
[Experimental("SKEXP0120")]
public FunctionCallContentBuilder(JsonSerializerOptions jsonSerializerOptions)
public FunctionCallContentBuilder(JsonSerializerOptions? jsonSerializerOptions = null, bool retainArgumentTypes = false)
{
this._jsonSerializerOptions = jsonSerializerOptions;
this._retainArgumentTypes = retainArgumentTypes;
}

/// <summary>
Expand Down Expand Up @@ -146,7 +149,7 @@ public IReadOnlyList<FunctionCallContent> Build()
arguments = JsonSerializer.Deserialize<KernelArguments>(argumentsString);
}

if (arguments is { Count: > 0 })
if (arguments is { Count: > 0 } && !this._retainArgumentTypes)
{
var names = arguments.Names.ToArray();
foreach (var name in names)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,60 @@ public void ItShouldCaptureArgumentsDeserializationException(JsonSerializerOptio
Assert.NotNull(functionCall.Exception);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public void ItShouldRetainArgumentTypesIfSpecified(bool retain)
{
// Arrange
var sut = new FunctionCallContentBuilder(null, retainArgumentTypes: retain);

// Act
var update1 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 1, functionCallIndex: 2, callId: "f_101", name: null, arguments: null);
sut.Append(update1);

var update2 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 1, functionCallIndex: 2, callId: null, name: "WeatherUtils-GetTemperature", arguments: null);
sut.Append(update2);

var update3 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 1, functionCallIndex: 2, callId: null, name: null, arguments: "{\"city\":");
sut.Append(update3);

var update4 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 1, functionCallIndex: 2, callId: null, name: null, arguments: "\"Seattle\",");
sut.Append(update4);

var update5 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 1, functionCallIndex: 2, callId: null, name: null, arguments: "\"temperature\":");
sut.Append(update5);

var update6 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 1, functionCallIndex: 2, callId: null, name: null, arguments: "20}");
sut.Append(update6);

var functionCalls = sut.Build();

// Assert
var functionCall = Assert.Single(functionCalls);

Assert.Equal("f_101", functionCall.Id);
Assert.Equal("WeatherUtils", functionCall.PluginName);
Assert.Equal("GetTemperature", functionCall.FunctionName);
Assert.NotNull(functionCall.Arguments);

if (retain)
{
var city = Assert.IsType<JsonElement>(functionCall.Arguments?["city"]);
Assert.Equal(JsonValueKind.String, city.ValueKind);
Assert.Equal("Seattle", city.GetString());

var temperature = Assert.IsType<JsonElement>(functionCall.Arguments?["temperature"]);
Assert.Equal(JsonValueKind.Number, temperature.ValueKind);
Assert.Equal(20, temperature.GetInt32());
}
else
{
Assert.Equal("Seattle", functionCall.Arguments?["city"]);
Assert.Equal("20", functionCall.Arguments?["temperature"]);
}
}

private static StreamingChatMessageContent CreateStreamingContentWithFunctionCallUpdate(int choiceIndex, int functionCallIndex, string? callId, string? name, string? arguments, int requestIndex = 0)
{
var content = new StreamingChatMessageContent(AuthorRole.Assistant, null);
Expand Down
Loading