5
5
using System . Collections . Generic ;
6
6
using System . Runtime . CompilerServices ;
7
7
using System . Text ;
8
+ using System . Text . Json ;
9
+ using System . Text . Json . Serialization ;
8
10
using System . Threading ;
9
11
using System . Threading . Tasks ;
10
12
using Microsoft . Extensions . AI ;
14
16
namespace Microsoft . ML . OnnxRuntimeGenAI ;
15
17
16
18
/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with an ONNX Runtime GenAI <see cref="Model"/>.</summary>
17
- public sealed class OnnxRuntimeGenAIChatClient : IChatClient
19
+ public sealed partial class OnnxRuntimeGenAIChatClient : IChatClient
18
20
{
19
21
/// <summary>Options used to configure the instance's behavior.</summary>
20
22
private readonly OnnxRuntimeGenAIChatClientOptions ? _options ;
@@ -220,11 +222,11 @@ generator.ConversationId is null ||
220
222
{
221
223
ConversationId = generator . ConversationId ,
222
224
Contents = [ new UsageContent ( new ( )
223
- {
224
- InputTokenCount = inputTokens ,
225
- OutputTokenCount = outputTokens ,
226
- TotalTokenCount = inputTokens + outputTokens ,
227
- } ) ] ,
225
+ {
226
+ InputTokenCount = inputTokens ,
227
+ OutputTokenCount = outputTokens ,
228
+ TotalTokenCount = inputTokens + outputTokens ,
229
+ } ) ] ,
228
230
CreatedAt = DateTimeOffset . UtcNow ,
229
231
FinishReason = options is not null && options . MaxOutputTokens <= outputTokens ? ChatFinishReason . Length : ChatFinishReason . Stop ,
230
232
MessageId = messageId ,
@@ -251,7 +253,7 @@ generator.ConversationId is null ||
251
253
throw new ArgumentNullException ( nameof ( serviceType ) ) ;
252
254
}
253
255
254
- return
256
+ return
255
257
serviceKey is not null ? null :
256
258
serviceType == typeof ( ChatClientMetadata ) ? _metadata :
257
259
serviceType == typeof ( Model ) ? _model :
@@ -267,15 +269,41 @@ private bool IsStop(string token, ChatOptions? options) =>
267
269
_options ? . StopSequences ? . Contains ( token ) is true ;
268
270
269
271
/// <summary>Formats messages into a prompt using a default format.</summary>
270
- private static string FormatPromptDefault ( IEnumerable < ChatMessage > messages , ChatOptions ? options )
272
+ private string FormatPromptDefault ( IEnumerable < ChatMessage > messages , ChatOptions ? options )
271
273
{
272
- StringBuilder sb = new ( ) ;
274
+ SerializableMessage m = new ( ) ;
275
+
276
+ StringBuilder prompt = new ( ) ;
277
+ string separator = "" ;
278
+ prompt . Append ( '[' ) ;
273
279
foreach ( var message in messages )
274
280
{
275
- sb . Append ( message ) . AppendLine ( ) ;
281
+ if ( message . Text is string text )
282
+ {
283
+ prompt . Append ( separator ) ;
284
+ separator = "," ;
285
+
286
+ m . Role = message . Role . Value ;
287
+ m . Content = text ;
288
+ prompt . Append ( JsonSerializer . Serialize ( m , OnnxJsonContext . Default . SerializableMessage ) ) ;
289
+ }
276
290
}
291
+ prompt . Append ( ']' ) ;
292
+
293
+ return _tokenizer . ApplyChatTemplate (
294
+ template_str : null ,
295
+ messages : prompt . ToString ( ) ,
296
+ tools : null ,
297
+ add_generation_prompt : true ) ;
298
+ }
299
+
300
+ private sealed class SerializableMessage
301
+ {
302
+ [ JsonPropertyName ( "role" ) ]
303
+ public string Role { get ; set ; } = string . Empty ;
277
304
278
- return sb . ToString ( ) ;
305
+ [ JsonPropertyName ( "content" ) ]
306
+ public string Content { get ; set ; } = string . Empty ;
279
307
}
280
308
281
309
/// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
@@ -351,6 +379,11 @@ private static void UpdateGeneratorParamsFromOptions(GeneratorParams generatorPa
351
379
}
352
380
}
353
381
}
382
+
383
+ if ( options . ResponseFormat is ChatResponseFormatJson json )
384
+ {
385
+ generatorParams . SetGuidance ( "json_schema" , json . Schema is { } schema ? schema . ToString ( ) : "{}" ) ;
386
+ }
354
387
}
355
388
356
389
private sealed class CachedGenerator ( Generator generator ) : IDisposable
@@ -362,6 +395,9 @@ private sealed class CachedGenerator(Generator generator) : IDisposable
362
395
public void Dispose ( ) => Generator ? . Dispose ( ) ;
363
396
}
364
397
398
+ [ JsonSerializable ( typeof ( SerializableMessage ) ) ]
399
+ private partial class OnnxJsonContext : JsonSerializerContext ;
400
+
365
401
/// <summary>Polyfill for Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);</summary>
366
402
private sealed class YieldAwaiter : INotifyCompletion
367
403
{
0 commit comments