// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.Tokenizers;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core;

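/// <summary>
/// Exposes an <see cref="ICausalLMPipeline{TTokenizer, TCausalLMModel}"/> through the
/// Microsoft.Extensions.AI <see cref="IChatClient"/> abstraction, using an
/// <see cref="IMEAIChatTemplateBuilder"/> to render chat messages into a model prompt.
/// </summary>
/// <remarks>
/// A minimal usage sketch, assuming a hypothetical concrete subclass <c>MyCausalLMChatClient</c>
/// and already-constructed <c>pipeline</c> and <c>templateBuilder</c> instances:
/// <code>
/// IChatClient client = new MyCausalLMChatClient(pipeline, templateBuilder);
/// var completion = await client.CompleteAsync(
///     new List&lt;ChatMessage&gt; { new(ChatRole.User, "Hello!") },
///     new ChatOptions { MaxOutputTokens = 256, Temperature = 0.7f });
/// </code>
/// </remarks>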
public abstract class CausalLMPipelineChatClient<TTokenizer, TCausalLMModel> : IChatClient
    where TTokenizer : Tokenizer
    where TCausalLMModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
    private readonly ICausalLMPipeline<TTokenizer, TCausalLMModel> _pipeline;
    private readonly IMEAIChatTemplateBuilder _chatTemplateBuilder;

    public CausalLMPipelineChatClient(
        ICausalLMPipeline<TTokenizer, TCausalLMModel> pipeline,
        IMEAIChatTemplateBuilder chatTemplateBuilder,
        ChatClientMetadata? metadata = null)
    {
        var classNameWithType = $"{nameof(CausalLMPipelineChatClient<TTokenizer, TCausalLMModel>)}<{typeof(TTokenizer).Name}, {typeof(TCausalLMModel).Name}>";
        // Prefer caller-supplied metadata; otherwise derive it from the concrete generic type.
        Metadata = metadata ?? new ChatClientMetadata(providerName: classNameWithType, modelId: typeof(TCausalLMModel).Name);
        _chatTemplateBuilder = chatTemplateBuilder;
        _pipeline = pipeline;
    }

    public ChatClientMetadata Metadata { get; }

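    /// <summary>
    /// Builds a prompt from <paramref name="chatMessages"/> with the chat template builder and runs a
    /// single, non-streaming generation on the underlying pipeline.
    /// </summary>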
    public virtual Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
    {
        var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
        var stopSequences = options?.StopSequences ?? Array.Empty<string>();

        var output = _pipeline.Generate(
            prompt,
            maxLen: options?.MaxOutputTokens ?? 1024,
            temperature: options?.Temperature ?? 0.7f,
            stopSequences: stopSequences.ToArray()) ?? throw new InvalidOperationException("Failed to generate a reply.");

        var chatMessage = new ChatMessage(ChatRole.Assistant, output);
        return Task.FromResult(new ChatCompletion([chatMessage])
        {
            CreatedAt = DateTime.UtcNow,
            FinishReason = ChatFinishReason.Stop,
        });
    }

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
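    /// <summary>
    /// Builds a prompt from <paramref name="chatMessages"/> and streams the pipeline's output back as
    /// one <see cref="StreamingChatCompletionUpdate"/> per generated chunk.
    /// </summary>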
    public virtual async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
        var stopSequences = options?.StopSequences ?? Array.Empty<string>();

        foreach (var output in _pipeline.GenerateStreaming(
            prompt,
            maxLen: options?.MaxOutputTokens ?? 1024,
            temperature: options?.Temperature ?? 0.7f,
            stopSequences: stopSequences.ToArray()))
        {
            yield return new StreamingChatCompletionUpdate
            {
                Role = ChatRole.Assistant,
                Text = output,
                CreatedAt = DateTime.UtcNow,
            };
        }
    }
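    /// <summary>
    /// Releases resources held by the client. The base implementation holds no disposable state, so this is
    /// a no-op; derived types can override it if they own native or managed resources.
    /// </summary>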
    public virtual void Dispose()
    {
    }

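    /// <summary>
    /// Returns an optional service of type <typeparamref name="TService"/>. The base implementation exposes
    /// no additional services and always returns <see langword="null"/>; derived types can override this.
    /// </summary>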
    public virtual TService? GetService<TService>(object? key = null) where TService : class
    {
        return null;
    }
}