Skip to content

Commit fb63193

Browse files
committed
fixes for function calling
1 parent 3e880a3 commit fb63193

File tree

9 files changed

+137
-45
lines changed

9 files changed

+137
-45
lines changed

.github/workflows/dotnet.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ jobs:
2828
run: dotnet restore
2929

3030
- name: Build
31-
run: dotnet build --no-restore
31+
run: dotnet build --no-restore --configuration Release
3232

3333
- name: Test and Collect Code Coverage
34-
run: dotnet test -p:CollectCoverage=true -p:CoverletOutput=coverage/
34+
run: dotnet test --configuration Release -p:CollectCoverage=true -p:CoverletOutput=coverage/
3535

3636
- name: Copy coverage files
3737
run: |

Directory.Build.props

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
<RepositoryUrl>https://github.com/managedcode/together-dotnet</RepositoryUrl>
2626
<PackageProjectUrl>https://github.com/managedcode/together-dotnet</PackageProjectUrl>
2727
<Product>Together.AI .NET/C# SDK</Product>
28-
<Version>0.0.1</Version>
29-
<PackageVersion>0.0.1</PackageVersion>
28+
<Version>0.0.2</Version>
29+
<PackageVersion>0.0.2</PackageVersion>
3030

3131
</PropertyGroup>
3232
<PropertyGroup Condition="'$(GITHUB_ACTIONS)' == 'true'">

Together.SemanticKernel/Services/TogetherChatCompletionService.cs

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
6060
Stream = false
6161
};
6262

63-
var toolConfig = GetToolConfiguration(kernel, executionSettings, requestIndex);
63+
var toolConfig = GetToolConfiguration(kernel, chatHistory, executionSettings, requestIndex);
6464
if (toolConfig.HasTools)
6565
{
6666
ConfigureTools(kernel, request);
@@ -81,13 +81,13 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
8181
}
8282

8383
var result = response.Choices.First();
84-
var messageContent = CreateChatMessageContent(new ChatCompletionMessage
84+
var message = new ChatCompletionMessage
8585
{
8686
Role = result.Message.Role,
8787
Content = result.Message.Content,
8888
ToolCalls = result.Message
8989
.ToolCalls
90-
.Select(t => new ToolCall
90+
?.Select(t => new ToolCall
9191
{
9292
Id = t.Id,
9393
Type = t.Type,
@@ -98,27 +98,34 @@ public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync
9898
}
9999
})
100100
.ToList()
101-
});
102-
101+
};
102+
103+
var messageContent = CreateChatMessageContent(message);
104+
103105
// If no tool calls or no auto-invoke, return response
104106
if (!toolConfig.AutoInvoke || result.Message.ToolCalls?.Any() != true)
105107
{
106-
return new[] { messageContent };
108+
return [messageContent];
107109
}
110+
111+
//add history
112+
chatHistory.Add(messageContent);
108113

109114
// Process tool calls asynchronously
110115
foreach (var toolCall in result.Message.ToolCalls)
111116
{
112-
if (!await ProcessToolCallAsync(new ToolCall
117+
var process = await ProcessToolCallAsync(new ToolCall
118+
{
119+
Id = toolCall.Id,
120+
Type = toolCall.Type,
121+
Function = new FunctionCall
113122
{
114-
Id = toolCall.Id,
115-
Type = toolCall.Type,
116-
Function = new FunctionCall
117-
{
118-
Name = toolCall.Function.Name,
119-
Arguments = toolCall.Function.Arguments
120-
}
121-
}, kernel, chatHistory, cancellationToken))
123+
Name = toolCall.Function.Name,
124+
Arguments = toolCall.Function.Arguments
125+
}
126+
}, kernel, chatHistory, cancellationToken);
127+
128+
if (!process)
122129
{
123130
_logger.LogWarning("Failed to process tool call: {ToolCall}", toolCall.Function?.Name);
124131
}
@@ -221,7 +228,7 @@ private async Task<bool> ProcessToolCallAsync(ToolCall toolCall, Kernel? kernel,
221228
// Await asynchronous function invocation instead of using .Result
222229
var result = await function.InvokeAsync(kernel, args, cancellationToken);
223230

224-
chatHistory.Add(new ChatMessageContent(AuthorRole.Tool, result.GetValue<string>(),
231+
chatHistory.Add(new ChatMessageContent(AuthorRole.Tool, result.ToString(),
225232
metadata: new Dictionary<string, object?> { { "tool_call_id", toolCall.Id } }));
226233

227234
return true;
@@ -247,7 +254,7 @@ private KernelArguments ParseArguments(string argumentJson)
247254
}
248255
}
249256

250-
private ToolConfiguration GetToolConfiguration(Kernel? kernel, PromptExecutionSettings? settings, int requestIndex)
257+
private ToolConfiguration GetToolConfiguration(Kernel? kernel, ChatHistory chatHistory, PromptExecutionSettings? settings, int requestIndex)
251258
{
252259
if (kernel == null)
253260
{
@@ -262,6 +269,17 @@ private ToolConfiguration GetToolConfiguration(Kernel? kernel, PromptExecutionSe
262269
{
263270
return new ToolConfiguration(false, false, 0);
264271
}
272+
273+
if(settings is TogetherPromptExecutionSettings togetherSettings)
274+
{
275+
var call = togetherSettings.FunctionChoiceBehavior?.GetConfiguration(new FunctionChoiceBehaviorConfigurationContext(chatHistory)
276+
{
277+
Kernel = kernel
278+
});
279+
return new ToolConfiguration(hasTools, call?.AutoInvoke ?? false, 1);
280+
}
281+
282+
//TODO: check type of settings
265283

266284
// Check execution settings
267285
var autoInvoke = false;
@@ -353,10 +371,11 @@ private void ConfigureTools(Kernel kernel, ChatCompletionRequest request)
353371
})
354372
.ToList();
355373

374+
356375
request.ToolChoice = new ToolChoice
357376
{
358377
Type = "auto",
359-
Function = new FunctionToolChoice { Name = "auto" }
378+
Function = new FunctionToolChoice { Name = "auto" },
360379
};
361380
}
362381

Together.SemanticKernel/Together.SemanticKernel.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
</ItemGroup>
1010

1111
<ItemGroup>
12-
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.36.1"/>
12+
<PackageReference Include="Microsoft.SemanticKernel.Core" Version="1.38.0" />
1313
</ItemGroup>
1414

1515

Together.Tests/SemanticKernelIntegraionTests.cs

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ public class SemanticKernelIntegraionTests
1616
private static readonly string API_KEY = "API_KEY";
1717

1818

19+
private string TextModel = "mistralai/Mistral-7B-Instruct-v0.1";//"meta-llama/Llama-3.3-70B-Instruct-Turbo";
20+
private string ImageModel = "black-forest-labs/FLUX.1-dev";
21+
private string EmbeddedModel = "togethercomputer/m2-bert-80M-2k-retrieval";
22+
1923
[Fact
2024
#if !API_TEST
2125
(Skip = "This test is skipped because it requires a valid API key")
@@ -25,7 +29,7 @@ public class SemanticKernelIntegraionTests
2529
public async Task InvokeOpenAiPromptAsyncTest()
2630
{
2731
var kernel = Kernel.CreateBuilder()
28-
.AddOpenAIChatCompletion("meta-llama/Meta-Llama-3-70B-Instruct-Turbo", new Uri("https://api.together.xyz/v1"), API_KEY)
32+
.AddOpenAIChatCompletion(TextModel, new Uri("https://api.together.xyz/v1"), API_KEY)
2933
.Build();
3034

3135
var answer = await kernel.InvokePromptAsync("Hi");
@@ -38,10 +42,10 @@ public async Task InvokeOpenAiPromptAsyncTest()
3842
#endif
3943
]
4044
[Experimental("SKEXP0010")]
41-
public async Task FunctionCallTest()
45+
public async Task OpenAIFunctionCallTest()
4246
{
4347
var kernel = Kernel.CreateBuilder()
44-
.AddOpenAIChatCompletion("mistralai/Mistral-7B-Instruct-v0.1", new Uri("https://api.together.xyz/v1"), API_KEY)
48+
.AddOpenAIChatCompletion(TextModel, new Uri("https://api.together.xyz/v1"), API_KEY)
4549
.Build();
4650

4751
var call = false;
@@ -64,6 +68,39 @@ public async Task FunctionCallTest()
6468
answer.RenderedPrompt.ShouldNotBeEmpty();
6569
call.ShouldBeTrue();
6670
}
71+
72+
[Fact
73+
#if !API_TEST
74+
(Skip = "This test is skipped because it requires a valid API key")
75+
#endif
76+
]
77+
[Experimental("SKEXP0010")]
78+
public async Task FunctionCallTest()
79+
{
80+
var kernel = Kernel.CreateBuilder()
81+
.AddTogetherChatCompletion(TextModel, API_KEY)
82+
.Build();
83+
84+
var call = false;
85+
86+
kernel.Plugins.AddFromFunctions("time_plugin", [
87+
KernelFunctionFactory.CreateFromMethod(() =>
88+
{
89+
call = true;
90+
return DateTime.Now;
91+
}, "get_time", "Get the current time")
92+
]);
93+
94+
var message = await kernel.GetRequiredService<IChatCompletionService>()
95+
.GetChatMessageContentAsync("What is the current time?", new TogetherPromptExecutionSettings
96+
{
97+
FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
98+
}, kernel);
99+
100+
var answer = await kernel.InvokePromptAsync("What is the current time?");
101+
answer.RenderedPrompt.ShouldNotBeEmpty();
102+
call.ShouldBeTrue();
103+
}
67104

68105
[Fact
69106
#if !API_TEST
@@ -73,7 +110,7 @@ public async Task FunctionCallTest()
73110
public async Task InvokePromptAsyncTest()
74111
{
75112
var kernel = Kernel.CreateBuilder()
76-
.AddTogetherChatCompletion("mistralai/Mistral-7B-Instruct-v0.1", API_KEY)
113+
.AddTogetherChatCompletion(TextModel, API_KEY)
77114
.Build();
78115

79116
var answer = await kernel.InvokePromptAsync("Hi");
@@ -88,28 +125,47 @@ public async Task InvokePromptAsyncTest()
88125
public async Task CompletionTest()
89126
{
90127
var kernel = Kernel.CreateBuilder()
91-
.AddTogetherChatCompletion("mistralai/Mistral-7B-Instruct-v0.1", API_KEY)
128+
.AddTogetherChatCompletion(TextModel, API_KEY)
92129
.Build();
93130

94-
var call = false;
131+
var call1 = false;
132+
var call2 = false;
95133

96134
kernel.Plugins.AddFromFunctions("time_plugin", [
97135
KernelFunctionFactory.CreateFromMethod(() =>
98136
{
99-
call = true;
137+
call1 = true;
100138
return DateTime.Now;
101139
}, "get_time", "Get the current time")
102140
]);
141+
kernel.Plugins.AddFromFunctions("news_plugin", [
142+
KernelFunctionFactory.CreateFromMethod((string day) =>
143+
{
144+
call2 = true;
145+
return $"Today is {day}, and we have 5 news items.";
146+
}, "get_news_for_day", "get news for specific day")
147+
]);
103148

104-
var message = await kernel.GetRequiredService<IChatCompletionService>()
105-
.GetChatMessageContentAsync("What is the current time?", new OpenAIPromptExecutionSettings
149+
var message1 = await kernel.GetRequiredService<IChatCompletionService>()
150+
.GetChatMessageContentAsync("What is the current time?", new TogetherPromptExecutionSettings
151+
{
152+
FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
153+
}, kernel);
154+
155+
var message2 = await kernel.GetRequiredService<IChatCompletionService>()
156+
.GetChatMessageContentAsync("how many news we have for monday?", new TogetherPromptExecutionSettings
106157
{
107158
FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
108159
}, kernel);
109160

110-
var answer = await kernel.InvokePromptAsync("What is the current time?");
111-
answer.RenderedPrompt.ShouldNotBeEmpty();
112-
call.ShouldBeTrue();
161+
var answer1 = await kernel.InvokePromptAsync("What is the current time?");
162+
var answer2 = await kernel.InvokePromptAsync("how many news we have for monday?");
163+
message1.Content.ShouldNotBeEmpty();
164+
message2.Content.ShouldNotBeEmpty();
165+
answer1.Metadata.Count.ShouldBePositive();
166+
answer2.Metadata.Count.ShouldBePositive();
167+
call1.ShouldBeTrue();
168+
call2.ShouldBeTrue();
113169
}
114170

115171
[Fact
@@ -121,7 +177,7 @@ public async Task CompletionTest()
121177
public async Task ImageTest()
122178
{
123179
var kernel = Kernel.CreateBuilder()
124-
.AddTogetherTextToImage("black-forest-labs/FLUX.1-dev", API_KEY)
180+
.AddTogetherTextToImage(ImageModel, API_KEY)
125181
.Build();
126182

127183

@@ -152,7 +208,7 @@ public async Task ImageTest()
152208
public async Task Embedded()
153209
{
154210
var kernel = Kernel.CreateBuilder()
155-
.AddTogetherTextEmbeddingGeneration("togethercomputer/m2-bert-80M-2k-retrieval", API_KEY)
211+
.AddTogetherTextEmbeddingGeneration(EmbeddedModel, API_KEY)
156212
.Build();
157213

158214

Together.Tests/Together.Tests.csproj

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55
<IsPackable>false</IsPackable>
66
<IsTestProject>true</IsTestProject>
77
</PropertyGroup>
8+
9+
<PropertyGroup Condition=" '$(Configuration)' == 'Debug' ">
10+
<DefineConstants>TRACE; API_TEST</DefineConstants>
11+
</PropertyGroup>
812

913
<ItemGroup>
1014
<PackageReference Include="coverlet.collector" Version="6.0.4">
1115
<PrivateAssets>all</PrivateAssets>
1216
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
1317
</PackageReference>
1418
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.13.0"/>
15-
<PackageReference Include="Microsoft.SemanticKernel" Version="1.36.1"/>
19+
<PackageReference Include="Microsoft.SemanticKernel" Version="1.38.0" />
1620
<PackageReference Include="Shouldly" Version="4.3.0"/>
1721
<PackageReference Include="System.Linq.Async" Version="6.0.1"/>
1822
<PackageReference Include="xunit" Version="2.9.3"/>

Together.Tests/TogetherClientIntegraionTests.cs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ public class TogetherClientIntegraionTests
1111
{
1212
private static readonly string API_KEY = "API_KEY";
1313

14+
private string TextModel = "meta-llama/Llama-3.3-70B-Instruct-Turbo";
15+
private string ImageModel = "black-forest-labs/FLUX.1-dev";
16+
private string EmbeddedModel = "togethercomputer/m2-bert-80M-2k-retrieval";
17+
1418
private TogetherClient CreateTogetherClient()
1519
{
1620
return new TogetherClient(API_KEY);
@@ -29,7 +33,7 @@ public async Task CompletionTest()
2933
var responseAsync = await client.Completions.CreateAsync(new CompletionRequest
3034
{
3135
Prompt = "Hi",
32-
Model = "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
36+
Model = TextModel,
3337
MaxTokens = 20
3438
});
3539

@@ -56,7 +60,7 @@ public async Task ChatCompletionTest()
5660
Content = "Hi"
5761
}
5862
},
59-
Model = "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
63+
Model = TextModel,
6064
MaxTokens = 20
6165
});
6266

@@ -84,7 +88,7 @@ public async Task StreamChatCompletionTest()
8488
Content = "Hi"
8589
}
8690
},
87-
Model = "meta-llama/Meta-Llama-3-70B-Instruct-Turbo",
91+
Model = TextModel,
8892
MaxTokens = 20,
8993
Stream = true
9094
})
@@ -107,7 +111,7 @@ public async Task EmbeddingTest()
107111
var responseAsync = await client.Embeddings.CreateAsync(new EmbeddingRequest
108112
{
109113
Input = "Hi",
110-
Model = "togethercomputer/m2-bert-80M-2k-retrieval"
114+
Model = EmbeddedModel
111115
});
112116

113117
Assert.NotNull(responseAsync.Data);
@@ -124,7 +128,7 @@ public async Task ImageTest()
124128

125129
var responseAsync = await client.Images.GenerateAsync(new ImageRequest
126130
{
127-
Model = "black-forest-labs/FLUX.1-dev",
131+
Model = ImageModel,
128132
Prompt = "Cats eating popcorn",
129133
N = 1,
130134
Steps = 10,

0 commit comments

Comments (0)