// Copyright (c) TensorStack. All rights reserved.
// Licensed under the Apache 2.0 License.

using System;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using TensorStack.Common;
using TensorStack.Common.Pipeline;
using TensorStack.Common.Tensor;
using TensorStack.TextGeneration.Cache;
using TensorStack.TextGeneration.Common;
using TensorStack.TextGeneration.Processing;
using TensorStack.TextGeneration.Tokenizers;
namespace TensorStack.TextGeneration.Pipelines.Llama
{
    public class LlamaPipeline : DecoderPipeline<GenerateOptions>,
        IPipeline<GenerateResult, GenerateOptions, GenerateProgress>,
        IPipeline<GenerateResult[], SearchOptions, GenerateProgress>
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="LlamaPipeline"/> class.
        /// </summary>
        /// <param name="configuration">The pipeline configuration, including the tokenizer and decoder settings.</param>
        public LlamaPipeline(LlamaConfig configuration)
            : base(configuration.Tokenizer, configuration.DecoderConfig)
        {
            Configuration = configuration;
        }

        public LlamaConfig Configuration { get; }


        /// <summary>
        /// Runs the GreedySearch inference.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The generated result.</returns>
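        /// <example>
        /// A minimal usage sketch (illustrative; assumes an <see cref="ExecutionProvider"/> instance named
        /// <c>provider</c>, and that the prompt is carried on <see cref="GenerateOptions"/>, whose exact
        /// property name is an assumption not shown here):
        /// <code>
        /// var pipeline = LlamaPipeline.Create(provider, "models/Llama-3.2-1B");
        /// var result = await pipeline.RunAsync(new GenerateOptions { MaxLength = 256 });
        /// Console.WriteLine(result.Result);
        /// </code>
        /// </example>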
        public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            await TokenizePromptAsync(options);
            var sequence = await GreedySearchAsync(options, progressCallback, cancellationToken);
            using (sequence)
            {
                return new GenerateResult
                {
                    Score = sequence.Score,
                    Result = Tokenizer.Decode(sequence.Tokens)
                };
            }
        }


        /// <summary>
        /// Runs the BeamSearch inference.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
        /// <returns>The generated results, one per beam.</returns>
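        /// <example>
        /// A minimal usage sketch (illustrative; assumes a <c>pipeline</c> instance and a constructed
        /// <see cref="SearchOptions"/>, whose beam settings are not shown here):
        /// <code>
        /// var results = await pipeline.RunAsync(searchOptions);
        /// foreach (var result in results)
        ///     Console.WriteLine($"Beam {result.Beam} (score {result.Score}): {result.Result}");
        /// </code>
        /// </example>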
        public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            await TokenizePromptAsync(options);

            var sequences = await BeamSearchAsync(options, progressCallback, cancellationToken);
            var results = new GenerateResult[sequences.Length];
            for (int beam = 0; beam < sequences.Length; beam++)
            {
                var sequence = sequences[beam];
                using (sequence)
                {
                    results[beam] = new GenerateResult
                    {
                        Beam = beam,
                        Score = sequence.Score,
                        PenaltyScore = sequence.PenaltyScore,
                        Result = Tokenizer.Decode(sequence.Tokens)
                    };
                }
            }
            return results;
        }


        /// <summary>
        /// Gets the token processors.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>ITokenProcessor[].</returns>
        protected override ITokenProcessor[] GetTokenProcessors(GenerateOptions options)
        {
            return
            [
                new EOSTokenProcessor(options.MinLength, Tokenizer.EOS),
                new MaxLengthTokenProcessor(options.MaxLength)
            ];
        }


        /// <summary>
        /// Initializes the Decoder cache.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>A <see cref="Task{TResult}"/> returning the initialized <see cref="Sequence"/>.</returns>
        protected override async Task<Sequence> InitializeAsync(GenerateOptions options)
        {
            var modelMetadata = await Decoder.LoadAsync();
            var kvCache = new KVCacheDecoder(modelMetadata, DecoderConfig.NumHeads, DecoderConfig.NumLayers, DecoderConfig.HiddenSize, DecoderConfig.NumKVHeads, options.MaxLength);
            var sequence = new Sequence(kvCache, Tokenizer.BOS);
            sequence.Initialize(0);

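            // Prefill: run the whole tokenized prompt through the decoder in a single
            // pass so the KV cache holds key/values for every prompt position.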
            var position = TokenizerOutput.Length;
            var inputIds = TokenizerOutput.InputIds;
            var positionIds = GetPositionIds(modelMetadata, 0, position);
            var attentionMask = new Tensor<long>([1, position], 1);
            RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, false);
            return sequence;
        }


        /// <summary>
        /// Runs the decoder model for a single generation step.
        /// </summary>
        /// <param name="sequence">The sequence.</param>
        /// <returns>A <see cref="Task{TResult}"/> returning the logits tensor.</returns>
        protected override async Task<Tensor<float>> RunDecoderAsync(Sequence sequence)
        {
            var modelMetadata = await Decoder.LoadAsync();
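            // Decode step: only the most recent token is fed to the model; earlier
            // positions are served from the sequence's KV cache.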
            var position = TokenizerOutput.Length + sequence.Tokens.Count;
            var inputIds = new Tensor<long>([1, 1], sequence.Tokens[^1]);
            var positionIds = GetPositionIds(modelMetadata, position);
            var attentionMask = new Tensor<long>([1, position], 1);
            return RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, true);
        }


        /// <summary>
        /// Runs the decoder.
        /// </summary>
        /// <param name="modelMetadata">The model metadata.</param>
        /// <param name="sequence">The sequence.</param>
        /// <param name="inputIds">The input ids.</param>
        /// <param name="positionIds">The position ids.</param>
        /// <param name="attentionMask">The attention mask.</param>
        /// <param name="useBranchCache">If set to <c>true</c>, updates the branch cache.</param>
        /// <returns>The logits tensor.</returns>
        private Tensor<float> RunDecoderInternal(ModelMetadata modelMetadata, Sequence sequence, Tensor<long> inputIds, Tensor<long> positionIds, Tensor<long> attentionMask, bool useBranchCache)
        {
            using (var parameters = new ModelParameters(modelMetadata))
            {
                // Inputs
                parameters.AddInput(inputIds);
                parameters.AddInput(attentionMask);
                if (positionIds != null)
                    parameters.AddInput(positionIds);

                foreach (var pastKeyValue in sequence.Cache)
                    parameters.AddInput(pastKeyValue, false);

                // Outputs
                foreach (var output in modelMetadata.Outputs)
                    parameters.AddOutput();

                // Result: output[0] is the logits; the remaining outputs are the
                // present key/value tensors, folded back into the sequence cache.
                var modelResult = Decoder.RunInference(parameters);
                using (var logitsResult = modelResult[0])
                {
                    var dimension = logitsResult.GetDimensions();
                    var logits = logitsResult.ToTensor(dimension[1..]);
                    var presentKeyValues = modelResult.ToArray()[1..];

                    sequence.UpdateCache(presentKeyValues, useBranchCache);
                    return logits;
                }
            }
        }


        /// <summary>
        /// Creates the LlamaPipeline.
        /// </summary>
        /// <param name="provider">The provider.</param>
        /// <param name="modelPath">The model path.</param>
        /// <param name="model">The decoder model filename.</param>
        /// <returns>LlamaPipeline.</returns>
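        /// <example>
        /// A minimal usage sketch (illustrative; assumes <paramref name="modelPath"/> contains the
        /// BPE tokenizer files alongside the ONNX decoder):
        /// <code>
        /// var pipeline = LlamaPipeline.Create(provider, "models/Llama-3.2-1B", "model.onnx");
        /// </code>
        /// </example>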
        public static LlamaPipeline Create(ExecutionProvider provider, string modelPath, string model = "model.onnx")
        {
            // Llama-3.2-1B
            var numHeads = 32;
            var numLayers = 16;
            var hiddenSize = 2048;
            var numKVHeads = 8;
            var vocabSize = 128256;
            var config = new LlamaConfig
            {
                Tokenizer = new BPETokenizer(new TokenizerConfig
                {
                    BOS = 128000,
                    EOS = 128001,
                    Path = modelPath
                }),
                DecoderConfig = new DecoderConfig
                {
                    Path = Path.Combine(modelPath, model),
                    VocabSize = vocabSize,
                    NumHeads = numHeads,
                    NumLayers = numLayers,
                    HiddenSize = hiddenSize,
                    NumKVHeads = numKVHeads
                }
            };

            config.DecoderConfig.SetProvider(provider);
            return new LlamaPipeline(config);
        }

    }
}