
Commit b07fda3

Update OnnxRuntimeGenAIChatClient with chat template and guidance (microsoft#1533)
1 parent 4dd97c4 commit b07fda3

File tree

1 file changed: +47 −11 lines changed

src/csharp/OnnxRuntimeGenAIChatClient.cs

Lines changed: 47 additions & 11 deletions
@@ -5,6 +5,8 @@
 using System.Collections.Generic;
 using System.Runtime.CompilerServices;
 using System.Text;
+using System.Text.Json;
+using System.Text.Json.Serialization;
 using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.Extensions.AI;
@@ -14,7 +16,7 @@
 namespace Microsoft.ML.OnnxRuntimeGenAI;

 /// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with an ONNX Runtime GenAI <see cref="Model"/>.</summary>
-public sealed class OnnxRuntimeGenAIChatClient : IChatClient
+public sealed partial class OnnxRuntimeGenAIChatClient : IChatClient
 {
     /// <summary>Options used to configure the instance's behavior.</summary>
     private readonly OnnxRuntimeGenAIChatClientOptions? _options;
@@ -220,11 +222,11 @@ generator.ConversationId is null ||
         {
             ConversationId = generator.ConversationId,
             Contents = [new UsageContent(new()
-            {
-                InputTokenCount = inputTokens,
-                OutputTokenCount = outputTokens,
-                TotalTokenCount = inputTokens + outputTokens,
-            })],
+            {
+                InputTokenCount = inputTokens,
+                OutputTokenCount = outputTokens,
+                TotalTokenCount = inputTokens + outputTokens,
+            })],
             CreatedAt = DateTimeOffset.UtcNow,
             FinishReason = options is not null && options.MaxOutputTokens <= outputTokens ? ChatFinishReason.Length : ChatFinishReason.Stop,
             MessageId = messageId,
@@ -251,7 +253,7 @@ generator.ConversationId is null ||
             throw new ArgumentNullException(nameof(serviceType));
         }

-        return
+        return
             serviceKey is not null ? null :
             serviceType == typeof(ChatClientMetadata) ? _metadata :
             serviceType == typeof(Model) ? _model :
@@ -267,15 +269,41 @@ private bool IsStop(string token, ChatOptions? options) =>
         _options?.StopSequences?.Contains(token) is true;

     /// <summary>Formats messages into a prompt using a default format.</summary>
-    private static string FormatPromptDefault(IEnumerable<ChatMessage> messages, ChatOptions? options)
+    private string FormatPromptDefault(IEnumerable<ChatMessage> messages, ChatOptions? options)
     {
-        StringBuilder sb = new();
+        SerializableMessage m = new();
+
+        StringBuilder prompt = new();
+        string separator = "";
+        prompt.Append('[');
         foreach (var message in messages)
         {
-            sb.Append(message).AppendLine();
+            if (message.Text is string text)
+            {
+                prompt.Append(separator);
+                separator = ",";
+
+                m.Role = message.Role.Value;
+                m.Content = text;
+                prompt.Append(JsonSerializer.Serialize(m, OnnxJsonContext.Default.SerializableMessage));
+            }
         }
+        prompt.Append(']');
+
+        return _tokenizer.ApplyChatTemplate(
+            template_str: null,
+            messages: prompt.ToString(),
+            tools: null,
+            add_generation_prompt: true);
+    }
+
+    private sealed class SerializableMessage
+    {
+        [JsonPropertyName("role")]
+        public string Role { get; set; } = string.Empty;

-        return sb.ToString();
+        [JsonPropertyName("content")]
+        public string Content { get; set; } = string.Empty;
     }

     /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
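For context, a minimal sketch of what the rewritten FormatPromptDefault now does: each text message is serialized into a {"role","content"} JSON object, the objects are wrapped in an array, and the tokenizer's ApplyChatTemplate builds the prompt from the model's own chat template. The model path and message texts below are illustrative assumptions; the ApplyChatTemplate arguments mirror the diff above.

using System.Text.Json;
using Microsoft.ML.OnnxRuntimeGenAI;

// Hypothetical model directory; substitute a real ONNX Runtime GenAI model folder.
using var model = new Model(@"path/to/model");
using var tokenizer = new Tokenizer(model);

// Same payload shape the client now builds: a JSON array of {"role","content"} objects.
var turns = new[]
{
    new { role = "system", content = "You are a helpful assistant." },
    new { role = "user", content = "Summarize this commit." },
};
string messagesJson = JsonSerializer.Serialize(turns);

// The model's bundled chat template formats the prompt instead of ad hoc string concatenation.
string prompt = tokenizer.ApplyChatTemplate(
    template_str: null,            // null selects the template shipped with the model
    messages: messagesJson,        // the JSON array built above
    tools: null,                   // no tool definitions
    add_generation_prompt: true);  // append the assistant-turn marker so generation can begin

Routing the prompt through the model's own chat template, rather than appending ChatMessage.ToString() lines, keeps role markers and special tokens consistent with how each model was trained.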
@@ -351,6 +379,11 @@ private static void UpdateGeneratorParamsFromOptions(GeneratorParams generatorPa
                 }
             }
         }
+
+        if (options.ResponseFormat is ChatResponseFormatJson json)
+        {
+            generatorParams.SetGuidance("json_schema", json.Schema is { } schema ? schema.ToString() : "{}");
+        }
     }

     private sealed class CachedGenerator(Generator generator) : IDisposable
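For context, a hedged sketch of how a caller reaches this new branch from the Microsoft.Extensions.AI side: setting ChatOptions.ResponseFormat to a JSON format is what the chat client now translates into GeneratorParams.SetGuidance("json_schema", ...). The schema below is an illustrative assumption, and the sketch assumes the ChatResponseFormat.ForJsonSchema factory from Microsoft.Extensions.AI.Abstractions.

using System.Text.Json;
using Microsoft.Extensions.AI;

// Illustrative schema describing the shape we want the model to emit.
JsonElement schema = JsonSerializer.Deserialize<JsonElement>("""
    {
      "type": "object",
      "properties": { "answer": { "type": "string" } },
      "required": ["answer"]
    }
    """);

// The chat client forwards this as SetGuidance("json_schema", schema.ToString());
// with plain ChatResponseFormat.Json (no schema), it falls back to the permissive "{}".
var options = new ChatOptions
{
    ResponseFormat = ChatResponseFormat.ForJsonSchema(schema),
};

Passing these options on a chat request then constrains token generation to output matching the schema.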
@@ -362,6 +395,9 @@ private sealed class CachedGenerator(Generator generator) : IDisposable
         public void Dispose() => Generator?.Dispose();
     }

+    [JsonSerializable(typeof(SerializableMessage))]
+    private partial class OnnxJsonContext : JsonSerializerContext;
+
     /// <summary>Polyfill for Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);</summary>
     private sealed class YieldAwaiter : INotifyCompletion
     {
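The new nested OnnxJsonContext uses System.Text.Json source generation, so serializing SerializableMessage needs no runtime reflection; that nesting is also why the containing class is now declared partial. A standalone sketch of the same pattern, with illustrative type names:

using System.Text.Json;
using System.Text.Json.Serialization;

// DTO mirroring SerializableMessage: properties map to lowercase JSON keys.
internal sealed class ChatTurn
{
    [JsonPropertyName("role")]
    public string Role { get; set; } = string.Empty;

    [JsonPropertyName("content")]
    public string Content { get; set; } = string.Empty;
}

// Partial context type; the source generator supplies reflection-free serialization metadata for ChatTurn.
[JsonSerializable(typeof(ChatTurn))]
internal partial class ChatTurnJsonContext : JsonSerializerContext;

internal static class ChatTurnDemo
{
    // Produces {"role":"user","content":"hello"} using the generated metadata.
    public static string Serialize() =>
        JsonSerializer.Serialize(
            new ChatTurn { Role = "user", Content = "hello" },
            ChatTurnJsonContext.Default.ChatTurn);
}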
