1+ using System ;
2+ using System . Collections . Generic ;
3+ using System . Linq ;
4+ using System . Runtime . CompilerServices ;
5+ using System . Text ;
6+ using System . Threading ;
7+ using System . Threading . Tasks ;
8+ using LLama . Common ;
9+ using LLama . Sampling ;
10+ using Microsoft . Extensions . AI ;
11+
12+ namespace LLama . Abstractions ;
13+
/// <summary>
/// Extension methods for the <see cref="ILLamaExecutor"/> interface.
/// </summary>
public static class LLamaExecutorExtensions
{
    /// <summary>Gets an <see cref="IChatClient"/> instance for the specified <see cref="ILLamaExecutor"/>.</summary>
    /// <param name="executor">The executor.</param>
    /// <param name="historyTransform">
    /// The <see cref="IHistoryTransform"/> to use to transform an input list of messages into a prompt.
    /// When <see langword="null"/>, a default transform that appends "Assistant: " to the rendered history is used.
    /// </param>
    /// <param name="outputTransform">
    /// The <see cref="ITextStreamTransform"/> to use to transform the output into text.
    /// When <see langword="null"/>, a keyword transform over the "User:", "Assistant:" and "System:" markers is used.
    /// </param>
    /// <returns>An <see cref="IChatClient"/> instance for the provided <see cref="ILLamaExecutor"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="executor"/> is null.</exception>
    public static IChatClient AsChatClient(
        this ILLamaExecutor executor,
        IHistoryTransform? historyTransform = null,
        ITextStreamTransform? outputTransform = null) =>
        new LLamaExecutorChatClient(
            executor ?? throw new ArgumentNullException(nameof(executor)),
            historyTransform,
            outputTransform);

    /// <summary>An <see cref="IChatClient"/> that forwards chat requests to an <see cref="ILLamaExecutor"/>.</summary>
    private sealed class LLamaExecutorChatClient(
        ILLamaExecutor executor,
        IHistoryTransform? historyTransform = null,
        ITextStreamTransform? outputTransform = null) : IChatClient
    {
        // Default-constructed instances, used only as the source of fallback values in CreateInferenceParams.
        private static readonly InferenceParams s_defaultParams = new();
        private static readonly DefaultSamplingPipeline s_defaultPipeline = new();

        // Role markers used both as inference anti-prompts and as the keywords of the default output transform.
        private static readonly string[] s_antiPrompts = ["User:", "Assistant:", "System:"];

        // Per-thread generator used to pick a seed when the caller does not supply one.
        [ThreadStatic]
        private static Random? t_random;

        private readonly ILLamaExecutor _executor = executor;
        private readonly IHistoryTransform _historyTransform = historyTransform ?? new AppendAssistantHistoryTransform();
        private readonly ITextStreamTransform _outputTransform = outputTransform ??
            new LLamaTransforms.KeywordTextOutputStreamTransform(s_antiPrompts);

        /// <inheritdoc/>
        public ChatClientMetadata Metadata { get; } = new(nameof(LLamaExecutorChatClient));

        /// <inheritdoc/>
        public void Dispose() { } // Nothing to release here; the wrapped executor is not disposed by this client.

        /// <inheritdoc/>
        /// <remarks>
        /// Returns the wrapped executor when <typeparamref name="TService"/> is exactly <see cref="ILLamaExecutor"/>;
        /// otherwise returns this client if it is a <typeparamref name="TService"/>, else <see langword="null"/>.
        /// The <paramref name="key"/> is not consulted.
        /// </remarks>
        public TService? GetService<TService>(object? key = null) where TService : class =>
            typeof(TService) == typeof(ILLamaExecutor) ? (TService)_executor :
            this as TService;

        /// <inheritdoc/>
        public async Task<ChatCompletion> CompleteAsync(
            IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
        {
            var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);

            // Flow the token into the transform's enumerator as well, so a transform that honors
            // cancellation can stop promptly; ConfigureAwait(false) because this is library code.
            StringBuilder text = new();
            await foreach (var token in _outputTransform.TransformAsync(result)
                .WithCancellation(cancellationToken)
                .ConfigureAwait(false))
            {
                text.Append(token);
            }

            return new(new ChatMessage(ChatRole.Assistant, text.ToString()))
            {
                CreatedAt = DateTime.UtcNow,
            };
        }

        /// <inheritdoc/>
        public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
            IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
        {
            var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);

            // Same cancellation/context handling as CompleteAsync; each transformed token becomes one update.
            await foreach (var token in _outputTransform.TransformAsync(result)
                .WithCancellation(cancellationToken)
                .ConfigureAwait(false))
            {
                yield return new()
                {
                    CreatedAt = DateTime.UtcNow,
                    Role = ChatRole.Assistant,
                    Text = token,
                };
            }
        }

        /// <summary>Format the chat messages into a string prompt.</summary>
        /// <param name="messages">The chat messages supplied by the caller.</param>
        /// <returns>The prompt text produced by the configured <see cref="IHistoryTransform"/>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="messages"/> is null.</exception>
        private string CreatePrompt(IList<ChatMessage> messages)
        {
            if (messages is null)
            {
                throw new ArgumentNullException(nameof(messages));
            }

            ChatHistory history = new();

            if (_executor is not StatefulExecutorBase seb ||
                seb.GetStateData() is InteractiveExecutor.InteractiveExecutorState { IsPromptRun: true })
            {
                // Either the executor is not stateful, or it is an interactive executor still on its
                // prompt run: render the whole conversation. Roles other than System/Assistant map to User.
                foreach (var message in messages)
                {
                    history.AddMessage(
                        message.Role == ChatRole.System ? AuthorRole.System :
                        message.Role == ChatRole.Assistant ? AuthorRole.Assistant :
                        AuthorRole.User,
                        string.Concat(message.Contents.OfType<TextContent>()));
                }
            }
            else
            {
                // Stateful executor whose state is not an interactive state with IsPromptRun = true:
                // use only the last message, always as a user message (empty text if there are no messages).
                // Presumably the executor already holds the earlier turns in its own state - confirm.
                history.AddMessage(AuthorRole.User, string.Concat(messages.LastOrDefault()?.Contents.OfType<TextContent>() ?? []));
            }

            return _historyTransform.HistoryToText(history);
        }

        /// <summary>Convert the chat options to inference parameters.</summary>
        /// <param name="options">
        /// The chat options, or <see langword="null"/>. Settings without a dedicated <see cref="ChatOptions"/>
        /// property are looked up in <see cref="ChatOptions.AdditionalProperties"/> under the name of the
        /// corresponding <see cref="InferenceParams"/> / <see cref="DefaultSamplingPipeline"/> property.
        /// </param>
        /// <returns>A new <see cref="InferenceParams"/> instance; this implementation never returns null.</returns>
        private static InferenceParams? CreateInferenceParams(ChatOptions? options)
        {
            var props = options?.AdditionalProperties;

            // Caller-supplied anti-prompts are added to the built-in role markers, not substituted for them.
            List<string> antiPrompts = new(s_antiPrompts);
            if (props?.TryGetValue(nameof(InferenceParams.AntiPrompts), out IReadOnlyList<string>? anti) is true)
            {
                antiPrompts.AddRange(anti);
            }

            return new()
            {
                AntiPrompts = antiPrompts,
                TokensKeep = props?.TryGetValue(nameof(InferenceParams.TokensKeep), out int tk) is true ? tk : s_defaultParams.TokensKeep,
                MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
                SamplingPipeline = new DefaultSamplingPipeline()
                {
                    AlphaFrequency = props?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaFrequency), out float af) is true ? af : s_defaultPipeline.AlphaFrequency,
                    AlphaPresence = props?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaPresence), out float ap) is true ? ap : s_defaultPipeline.AlphaPresence,
                    PenalizeEOS = props?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeEOS), out bool eos) is true ? eos : s_defaultPipeline.PenalizeEOS,
                    PenalizeNewline = props?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
                    RepeatPenalty = props?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
                    RepeatPenaltyCount = props?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.RepeatPenaltyCount,
                    Grammar = props?.TryGetValue(nameof(DefaultSamplingPipeline.Grammar), out Grammar? g) is true ? g : s_defaultPipeline.Grammar,
                    MinKeep = props?.TryGetValue(nameof(DefaultSamplingPipeline.MinKeep), out int mk) is true ? mk : s_defaultPipeline.MinKeep,
                    MinP = props?.TryGetValue(nameof(DefaultSamplingPipeline.MinP), out float mp) is true ? mp : s_defaultPipeline.MinP,
                    // No seed supplied: draw one from the per-thread generator (created on first use).
                    Seed = props?.TryGetValue(nameof(DefaultSamplingPipeline.Seed), out uint seed) is true ? seed : (uint)(t_random ??= new()).Next(),
                    TailFreeZ = props?.TryGetValue(nameof(DefaultSamplingPipeline.TailFreeZ), out float tfz) is true ? tfz : s_defaultPipeline.TailFreeZ,
                    // NOTE(review): Temperature and TopP fall back to 0 rather than to s_defaultPipeline,
                    // unlike every other sampler setting here - confirm this is intentional.
                    Temperature = options?.Temperature ?? 0,
                    TopP = options?.TopP ?? 0,
                    TopK = options?.TopK ?? s_defaultPipeline.TopK,
                    TypicalP = props?.TryGetValue(nameof(DefaultSamplingPipeline.TypicalP), out float tp) is true ? tp : s_defaultPipeline.TypicalP,
                },
            };
        }

        /// <summary>A default transform that appends "Assistant: " to the end.</summary>
        private sealed class AppendAssistantHistoryTransform : LLamaTransforms.DefaultHistoryTransform
        {
            /// <summary>Renders the history with the base transform, then appends the assistant role marker.</summary>
            public override string HistoryToText(ChatHistory history) =>
                $"{base.HistoryToText(history)}{AuthorRole.Assistant}: ";
        }
    }
}
0 commit comments