
Commit 7559adf

Add Llama-3.2-1B pipeline
1 parent 5d83acf commit 7559adf

4 files changed: +276 -1 lines changed

Directory.Build.props

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 <Project>
   <PropertyGroup>
-    <Version>0.1.31</Version>
+    <Version>0.1.33</Version>
     <Company>TensorStack</Company>
     <Copyright>TensorStack - 2025</Copyright>
     <RepositoryUrl>https://github.com/TensorStack-AI/TensorStack</RepositoryUrl>
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
using TensorStack.TextGeneration.Common;

namespace TensorStack.TextGeneration.Pipelines.Llama
{
    public record LlamaConfig : TransformerConfig
    {
    }
}
Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
// Copyright (c) TensorStack. All rights reserved.
// Licensed under the Apache 2.0 License.

using System;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using TensorStack.Common;
using TensorStack.Common.Pipeline;
using TensorStack.Common.Tensor;
using TensorStack.TextGeneration.Cache;
using TensorStack.TextGeneration.Common;
using TensorStack.TextGeneration.Processing;
using TensorStack.TextGeneration.Tokenizers;

namespace TensorStack.TextGeneration.Pipelines.Llama
{
    public class LlamaPipeline : DecoderPipeline<GenerateOptions>,
        IPipeline<GenerateResult, GenerateOptions, GenerateProgress>,
        IPipeline<GenerateResult[], SearchOptions, GenerateProgress>
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="LlamaPipeline"/> class.
        /// </summary>
        /// <param name="configuration">The pipeline configuration (tokenizer and decoder settings).</param>
        public LlamaPipeline(LlamaConfig configuration)
            : base(configuration.Tokenizer, configuration.DecoderConfig)
        {
            Configuration = configuration;
        }

        public LlamaConfig Configuration { get; }


        /// <summary>
        /// Runs the GreedySearch inference.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The generated result.</returns>
        public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            await TokenizePromptAsync(options);
            var sequence = await GreedySearchAsync(options, progressCallback, cancellationToken);
            using (sequence)
            {
                return new GenerateResult
                {
                    Score = sequence.Score,
                    Result = Tokenizer.Decode(sequence.Tokens)
                };
            }
        }


        /// <summary>
        /// Runs the BeamSearch inference.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <param name="progressCallback">The progress callback.</param>
        /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
        /// <returns>One result per beam.</returns>
        public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
        {
            await TokenizePromptAsync(options);

            var sequences = await BeamSearchAsync(options, progressCallback, cancellationToken);
            var results = new GenerateResult[sequences.Length];
            for (int beam = 0; beam < sequences.Length; beam++)
            {
                var sequence = sequences[beam];
                using (sequence)
                {
                    results[beam] = new GenerateResult
                    {
                        Beam = beam,
                        Score = sequence.Score,
                        PenaltyScore = sequence.PenaltyScore,
                        Result = Tokenizer.Decode(sequence.Tokens)
                    };
                }
            }
            return results;
        }


        /// <summary>
        /// Gets the token processors.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>ITokenProcessor[].</returns>
        protected override ITokenProcessor[] GetTokenProcessors(GenerateOptions options)
        {
            return
            [
                new EOSTokenProcessor
                (
                    options.MinLength, // min length
                    Tokenizer.EOS
                ),
                new MaxLengthTokenProcessor(options.MaxLength)
            ];
        }


        /// <summary>
        /// Initialize the Decoder cache.
        /// </summary>
        /// <param name="options">The options.</param>
        /// <returns>A Task&lt;Sequence&gt; representing the asynchronous operation.</returns>
        protected override async Task<Sequence> InitializeAsync(GenerateOptions options)
        {
            var modelMetadata = await Decoder.LoadAsync();
            var kvCache = new KVCacheDecoder(modelMetadata, DecoderConfig.NumHeads, DecoderConfig.NumLayers, DecoderConfig.HiddenSize, DecoderConfig.NumKVHeads, options.MaxLength);
            var sequence = new Sequence(kvCache, Tokenizer.BOS);
            sequence.Initialize(0);

            var position = TokenizerOutput.Length;
            var inputIds = TokenizerOutput.InputIds;
            var positionIds = GetPositionIds(modelMetadata, 0, position);
            var attentionMask = new Tensor<long>([1, position], 1);
            RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, false);
            return sequence;
        }


        /// <summary>
        /// Run decoder model.
        /// </summary>
        /// <param name="sequence">The sequence.</param>
        /// <returns>A Task&lt;Tensor`1&gt; representing the asynchronous operation.</returns>
        protected override async Task<Tensor<float>> RunDecoderAsync(Sequence sequence)
        {
            var modelMetadata = await Decoder.LoadAsync();
            var position = TokenizerOutput.Length + sequence.Tokens.Count;
            var inputIds = new Tensor<long>([1, 1], sequence.Tokens[^1]);
            var positionIds = GetPositionIds(modelMetadata, position);
            var attentionMask = new Tensor<long>([1, position], 1);
            return RunDecoderInternal(modelMetadata, sequence, inputIds, positionIds, attentionMask, true);
        }


        /// <summary>
        /// Runs the decoder.
        /// </summary>
        /// <param name="modelMetadata">The model metadata.</param>
        /// <param name="sequence">The sequence.</param>
        /// <param name="inputIds">The input ids.</param>
        /// <param name="positionIds">The position ids.</param>
        /// <param name="attentionMask">The attention mask.</param>
        /// <param name="useBranchCache">if set to <c>true</c> [use branch cache].</param>
        private Tensor<float> RunDecoderInternal(ModelMetadata modelMetadata, Sequence sequence, Tensor<long> inputIds, Tensor<long> positionIds, Tensor<long> attentionMask, bool useBranchCache)
        {
            using (var parameters = new ModelParameters(modelMetadata))
            {
                // Inputs
                parameters.AddInput(inputIds);
                parameters.AddInput(attentionMask);
                if (positionIds != null)
                    parameters.AddInput(positionIds);

                foreach (var pastKeyValue in sequence.Cache)
                    parameters.AddInput(pastKeyValue, false);

                // Outputs
                foreach (var output in modelMetadata.Outputs)
                    parameters.AddOutput();

                // Result
                var modelResult = Decoder.RunInference(parameters);
                using (var logitsResult = modelResult[0])
                {
                    var dimension = logitsResult.GetDimensions();
                    var logits = logitsResult.ToTensor(dimension[1..]);
                    var presentKeyValues = modelResult.ToArray()[1..];

                    sequence.UpdateCache(presentKeyValues, useBranchCache);
                    return logits;
                }
            }
        }


        /// <summary>
        /// Creates the LlamaPipeline.
        /// </summary>
        /// <param name="provider">The provider.</param>
        /// <param name="modelPath">The model path.</param>
        /// <param name="model">The decoder model filename.</param>
        /// <returns>LlamaPipeline.</returns>
        public static LlamaPipeline Create(ExecutionProvider provider, string modelPath, string model = "model.onnx")
        {
            // Llama-3.2-1B
            var numHeads = 32;
            var numLayers = 16;
            var hiddenSize = 2048;
            var numKVHeads = 8;
            var vocabSize = 128256;
            var config = new LlamaConfig
            {
                Tokenizer = new BPETokenizer(new TokenizerConfig
                {
                    BOS = 128000,
                    EOS = 128001,
                    Path = modelPath
                }),
                DecoderConfig = new DecoderConfig
                {
                    Path = Path.Combine(modelPath, model),
                    VocabSize = vocabSize,
                    NumHeads = numHeads,
                    NumLayers = numLayers,
                    HiddenSize = hiddenSize,
                    NumKVHeads = numKVHeads
                }
            };

            config.DecoderConfig.SetProvider(provider);
            return new LlamaPipeline(config);
        }

    }
}
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
# Llama Pipeline

Llama-3.2-1B: https://huggingface.co/TensorStack/Llama-3.2-1B

### Greedy
```csharp
var provider = Provider.GetProvider();
var modelPath = "M:\\Models\\Llama-3.2-1B";
var pipeline = LlamaPipeline.Create(provider, modelPath);
var options = new GenerateOptions
{
    Prompt = "What is an apple?"
};

var generateResult = await pipeline.RunAsync(options);
System.Console.WriteLine(generateResult.Result);
```
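
The pipeline reads `MinLength` and `MaxLength` from the options: `MinLength` gates the EOS token processor, while `MaxLength` drives the max-length processor and sizes the decoder KV cache. A minimal sketch of bounding the output length, assuming `GenerateOptions` exposes these as settable properties (the pipeline code above only reads them):

```csharp
// Minimal sketch: bound generation length (assumes GenerateOptions has settable
// MinLength/MaxLength, mirroring their use in GetTokenProcessors and InitializeAsync).
var boundedOptions = new GenerateOptions
{
    Prompt = "What is an apple?",
    MinLength = 8,    // EOS is not honoured before this many tokens
    MaxLength = 256   // hard stop; also the KV cache capacity
};

var boundedResult = await pipeline.RunAsync(boundedOptions);
System.Console.WriteLine(boundedResult.Result);
```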

### Beam Search
```csharp
var provider = Provider.GetProvider();
var modelPath = "M:\\Models\\Llama-3.2-1B";
var pipeline = LlamaPipeline.Create(provider, modelPath);
var options = new SearchOptions
{
    Seed = 0,
    TopK = 50,
    Beams = 3,
    TopP = 0.9f,
    Temperature = 1f,
    LengthPenalty = -1f,
    DiversityLength = 20,
    NoRepeatNgramSize = 3,
    EarlyStopping = EarlyStopping.None,
    Prompt = "What is an apple?"
};

foreach (var beamResult in await pipeline.RunAsync(options))
{
    System.Console.WriteLine(beamResult.Result);
}
```
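
Each element of the returned array also carries its `Beam` index, `Score`, and `PenaltyScore`, so the beams can be ranked instead of printed in order. A minimal sketch, assuming a higher `Score` indicates a better sequence:

```csharp
// Minimal sketch: pick the best beam by Score (assumes higher is better).
// Requires `using System.Linq;` for OrderByDescending.
var beamResults = await pipeline.RunAsync(options);
var bestBeam = beamResults.OrderByDescending(result => result.Score).First();
System.Console.WriteLine($"Beam {bestBeam.Beam} (score {bestBeam.Score}): {bestBeam.Result}");
```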
