Skip to content

Commit e881e84

Browse files
committed
TextGeneration/StableDiffusion merge
1 parent 7559adf commit e881e84

33 files changed

+475
-311
lines changed

BuildRelease.bat

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ dotnet pack TensorStack.Extractors/TensorStack.Extractors.csproj -c Release
4040
dotnet build TensorStack.Upscaler/TensorStack.Upscaler.csproj -c Release
4141
dotnet pack TensorStack.Upscaler/TensorStack.Upscaler.csproj -c Release
4242

43-
dotnet build TensorStack.StableDiffusion/TensorStack.StableDiffusion.csproj -c Release
44-
dotnet pack TensorStack.StableDiffusion/TensorStack.StableDiffusion.csproj -c Release
45-
4643
dotnet build TensorStack.TextGeneration/TensorStack.TextGeneration.csproj -c Release
4744
dotnet pack TensorStack.TextGeneration/TensorStack.TextGeneration.csproj -c Release
4845

46+
dotnet build TensorStack.StableDiffusion/TensorStack.StableDiffusion.csproj -c Release
47+
dotnet pack TensorStack.StableDiffusion/TensorStack.StableDiffusion.csproj -c Release
48+
4949
dotnet build TensorStack.WPF/TensorStack.WPF.csproj -c Release
5050
dotnet pack TensorStack.WPF/TensorStack.WPF.csproj -c Release

TensorStack.Common/Extensions/Extensions.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ public static T[] PadOrTruncate<T>(this T[] inputs, T padValue, int requiredLeng
123123
}
124124

125125

126+
public static T[] PadOrTruncate<T>(this ReadOnlySpan<T> inputs, T padValue, int requiredLength)
127+
{
128+
return PadOrTruncate(inputs.ToArray(), padValue, requiredLength);
129+
}
130+
131+
126132
public static T[] Pad<T>(this T[] inputs, T padValue, int requiredLength)
127133
{
128134
int count = inputs.Length;
@@ -137,6 +143,11 @@ public static T[] Pad<T>(this T[] inputs, T padValue, int requiredLength)
137143
return result;
138144
}
139145

146+
public static T[] Pad<T>(this ReadOnlySpan<T> inputs, T padValue, int requiredLength)
147+
{
148+
return Pad(inputs.ToArray(), padValue, requiredLength);
149+
}
150+
140151

141152
[MethodImpl(MethodImplOptions.AggressiveInlining)]
142153
public static float ZeroIfNan(this float value)

TensorStack.StableDiffusion/Common/TokenizerResult.cs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,24 @@
44

55
namespace TensorStack.StableDiffusion.Common
66
{
7-
public record TokenizerResult
8-
{
9-
public TokenizerResult(long[] inputIds, long[] attentionMask)
10-
{
11-
InputIds = inputIds;
12-
AttentionMask = attentionMask;
13-
Weights = [.. Enumerable.Repeat(1f, inputIds.Length)];
14-
}
7+
//public record TokenizerResult
8+
//{
9+
// public TokenizerResult(long[] inputIds, long[] attentionMask)
10+
// {
11+
// InputIds = inputIds;
12+
// AttentionMask = attentionMask;
13+
// Weights = [.. Enumerable.Repeat(1f, inputIds.Length)];
14+
// }
1515

16-
public TokenizerResult(long[] inputIds, long[] attentionMask, float[] weights)
17-
{
18-
InputIds = inputIds;
19-
AttentionMask = attentionMask;
20-
Weights = weights;
21-
}
16+
// public TokenizerResult(long[] inputIds, long[] attentionMask, float[] weights)
17+
// {
18+
// InputIds = inputIds;
19+
// AttentionMask = attentionMask;
20+
// Weights = weights;
21+
// }
2222

23-
public long[] InputIds { get; set; }
24-
public long[] AttentionMask { get; set; }
25-
public float[] Weights { get; set; }
26-
}
23+
// public long[] InputIds { get; set; }
24+
// public long[] AttentionMask { get; set; }
25+
// public float[] Weights { get; set; }
26+
//}
2727
}

TensorStack.StableDiffusion/Config/TokenizerConfig.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace TensorStack.StableDiffusion.Config
66
{
7-
public record TokenizerConfig : ModelConfig
8-
{
9-
}
7+
//public record TokenizerConfig : ModelConfig
8+
//{
9+
//}
1010
}

TensorStack.StableDiffusion/Helpers/PromptParser.cs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
using TensorStack.Common.Tensor;
1010
using TensorStack.StableDiffusion.Common;
1111
using TensorStack.StableDiffusion.Models;
12-
using TensorStack.StableDiffusion.Tokenizers;
12+
using TensorStack.TextGeneration.Tokenizers;
1313

1414
namespace TensorStack.StableDiffusion.Helpers
1515
{
@@ -34,8 +34,8 @@ public static async Task<TokenizerResult> TokenizePromptAsync(CLIPTokenizer toke
3434
foreach (var fragment in fragments)
3535
{
3636
var fragmentTokens = await tokenizer.EncodeAsync(fragment.Text, false);
37-
tokenIds.AddRange(fragmentTokens.InputIds);
38-
tokenWeights.AddRange(Enumerable.Repeat(fragment.Weight, fragmentTokens.InputIds.Length));
37+
tokenIds.AddRange(fragmentTokens.InputIds.Span);
38+
tokenWeights.AddRange(Enumerable.Repeat(fragment.Weight, fragmentTokens.InputIds.Span.Length));
3939
}
4040

4141
tokenIds.Add(tokenizer.EOS); //eos
@@ -64,9 +64,10 @@ public static async Task<TextEncoderResult> EncodePromptAsync(CLIPTextModel text
6464
}
6565
else
6666
{
67-
var bos = inputTokens.InputIds[0];
68-
var eos = inputTokens.InputIds[^1];
69-
var tokenIds = inputTokens.InputIds[1..^1].Pad(textEncoder.PadTokenId, minimumLength);
67+
68+
var bos = inputTokens.InputIds.Span[0];
69+
var eos = inputTokens.InputIds.Span[^1];
70+
var tokenIds = inputTokens.InputIds.Span[1..^1].Pad(textEncoder.PadTokenId, minimumLength);
7071

7172
// Create batches, 75 tokens + EOS & BOS
7273
var chunkSize = textEncoder.SequenceLength - 2;
@@ -149,7 +150,7 @@ public static void ApplyPromptWeights(TokenizerResult tokenizerOutput, TextEncod
149150
var hiddenStates = encoderOutput.HiddenStates;
150151
var numTokens = hiddenStates.Dimensions[1];
151152
var embedDim = hiddenStates.Dimensions[2];
152-
var weights = tokenizerOutput.Weights.Pad(1, numTokens);
153+
var weights = tokenizerOutput.Weights.Span.Pad(1, numTokens);
153154
if (weights.All(x => x == 1))
154155
return;
155156

TensorStack.StableDiffusion/Models/CLIPTextModel.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using TensorStack.Common.Tensor;
88
using TensorStack.StableDiffusion.Common;
99
using TensorStack.StableDiffusion.Config;
10+
using TensorStack.TextGeneration.Tokenizers;
1011

1112
namespace TensorStack.StableDiffusion.Models
1213
{
@@ -54,8 +55,8 @@ public virtual async Task<TextEncoderResult> RunAsync(TokenizerResult tokenInput
5455

5556
var paddedInput = PadOrTruncate(tokenInput);
5657
var supportsAttentionMask = Metadata.Inputs.Count == 2;
57-
var inputTensor = new TensorSpan<long>(paddedInput.InputIds, [1, SequenceLength]);
58-
var attentionTensor = new TensorSpan<long>(paddedInput.AttentionMask, [1, SequenceLength]);
58+
var inputTensor = paddedInput.InputIds.AsTensorSpan();
59+
var attentionTensor = paddedInput.Mask.AsTensorSpan();
5960
using (var modelParameters = new ModelParameters(Metadata, cancellationToken))
6061
{
6162
// Inputs
@@ -87,9 +88,9 @@ public virtual async Task<TextEncoderResult> RunAsync(TokenizerResult tokenInput
8788
/// <returns>TokenizerResult.</returns>
8889
protected TokenizerResult PadOrTruncate(TokenizerResult tokenizerResult)
8990
{
90-
var inputIds = tokenizerResult.InputIds.PadOrTruncate(PadTokenId, SequenceLength);
91-
var attentionMask = tokenizerResult.AttentionMask.PadOrTruncate(0, SequenceLength);
92-
var weights = tokenizerResult.Weights.PadOrTruncate(1, SequenceLength);
91+
var inputIds = tokenizerResult.InputIds.Span.PadOrTruncate(PadTokenId, SequenceLength);
92+
var attentionMask = tokenizerResult.Mask.Span.PadOrTruncate(0, SequenceLength);
93+
var weights = tokenizerResult.Weights is null ? default : tokenizerResult.Weights.Span.PadOrTruncate(1, SequenceLength);
9394
return new TokenizerResult(inputIds, attentionMask, weights);
9495
}
9596

TensorStack.StableDiffusion/Models/CLIPTextModelWithProjection.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using TensorStack.Common.Tensor;
88
using TensorStack.StableDiffusion.Common;
99
using TensorStack.StableDiffusion.Config;
10+
using TensorStack.TextGeneration.Tokenizers;
1011

1112
namespace TensorStack.StableDiffusion.Models
1213
{
@@ -36,8 +37,8 @@ public override async Task<TextEncoderResult> RunAsync(TokenizerResult tokenInpu
3637
var paddedInput = PadOrTruncate(tokenInput);
3738
var hiddenStateCount = Metadata.Outputs.Count - 1;
3839
var supportsAttentionMask = Metadata.Inputs.Count == 2;
39-
var inputTensor = new TensorSpan<long>(paddedInput.InputIds, [1, SequenceLength]);
40-
var attentionTensor = new TensorSpan<long>(paddedInput.AttentionMask, [1, SequenceLength]);
40+
var inputTensor = paddedInput.InputIds;
41+
var attentionTensor = paddedInput.Mask;
4142
using (var modelParameters = new ModelParameters(Metadata, cancellationToken))
4243
{
4344
// Inputs

TensorStack.StableDiffusion/Models/T5EncoderModel.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using TensorStack.Common.Tensor;
77
using TensorStack.StableDiffusion.Common;
88
using TensorStack.StableDiffusion.Config;
9+
using TensorStack.TextGeneration.Tokenizers;
910

1011
namespace TensorStack.StableDiffusion.Models
1112
{
@@ -41,8 +42,8 @@ public override async Task<TextEncoderResult> RunAsync(TokenizerResult tokenInpu
4142

4243
var sequenceLength = tokenInput.InputIds.Length;
4344
var supportsAttentionMask = Metadata.Outputs.Count == 2;
44-
var inputTensor = new TensorSpan<long>(tokenInput.InputIds, [1, sequenceLength]);
45-
var attentionTensor = new TensorSpan<long>(tokenInput.AttentionMask, [1, sequenceLength]);
45+
var inputTensor = tokenInput.InputIds;
46+
var attentionTensor = tokenInput.Mask;
4647
using (var modelParameters = new ModelParameters(Metadata, cancellationToken))
4748
{
4849
// Inputs
@@ -51,7 +52,7 @@ public override async Task<TextEncoderResult> RunAsync(TokenizerResult tokenInpu
5152
modelParameters.AddInput(attentionTensor);
5253

5354
// Outputs
54-
modelParameters.AddOutput([1, sequenceLength, HiddenSize]);
55+
modelParameters.AddOutput([1, (int)sequenceLength, HiddenSize]);
5556

5657
// Inference
5758
using (var results = await RunInferenceAsync(modelParameters))

TensorStack.StableDiffusion/Pipelines/Flux/FluxBase.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
using TensorStack.StableDiffusion.Helpers;
1414
using TensorStack.StableDiffusion.Models;
1515
using TensorStack.StableDiffusion.Schedulers;
16-
using TensorStack.StableDiffusion.Tokenizers;
16+
using TensorStack.TextGeneration.Tokenizers;
1717

1818
namespace TensorStack.StableDiffusion.Pipelines.Flux
1919
{
@@ -151,7 +151,7 @@ protected async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, C
151151
// Tokenize2
152152
var promptTokens = await TokenizePromptAsync(options.Prompt, cancellationToken);
153153
var negativePromptTokens = await TokenizePromptAsync(options.NegativePrompt, cancellationToken);
154-
var maxTokenLength = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);
154+
var maxTokenLength = (int)Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);
155155

156156
// Tokenizer2
157157
var prompt2Tokens = await TokenizePrompt2Async(options.Prompt, cancellationToken);

TensorStack.StableDiffusion/Pipelines/Flux/FluxConfig.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using TensorStack.Common;
55
using TensorStack.StableDiffusion.Config;
66
using TensorStack.StableDiffusion.Enums;
7+
using TensorStack.TextGeneration.Tokenizers;
78

89
namespace TensorStack.StableDiffusion.Pipelines.Flux
910
{
@@ -15,7 +16,7 @@ public record FluxConfig : PipelineConfig
1516
public FluxConfig()
1617
{
1718
Tokenizer = new TokenizerConfig();
18-
Tokenizer2 = new TokenizerConfig();
19+
Tokenizer2 = new TokenizerConfig{MaxLength = 512 };
1920
TextEncoder = new CLIPModelConfig();
2021
TextEncoder2 = new CLIPModelConfig
2122
{
@@ -54,8 +55,6 @@ public FluxConfig()
5455
/// <param name="executionProvider">The execution provider.</param>
5556
public override void SetProvider(ExecutionProvider executionProvider)
5657
{
57-
Tokenizer.SetProvider(executionProvider);
58-
Tokenizer2.SetProvider(executionProvider);
5958
TextEncoder.SetProvider(executionProvider);
6059
TextEncoder2.SetProvider(executionProvider);
6160
Transformer.SetProvider(executionProvider);

0 commit comments

Comments
 (0)