99using TensorStack . Common . Tensor ;
1010using TensorStack . StableDiffusion . Common ;
1111using TensorStack . StableDiffusion . Models ;
12- using TensorStack . StableDiffusion . Tokenizers ;
12+ using TensorStack . TextGeneration . Tokenizers ;
1313
1414namespace TensorStack . StableDiffusion . Helpers
1515{
@@ -34,8 +34,8 @@ public static async Task<TokenizerResult> TokenizePromptAsync(CLIPTokenizer toke
3434 foreach ( var fragment in fragments )
3535 {
3636 var fragmentTokens = await tokenizer . EncodeAsync ( fragment . Text , false ) ;
37- tokenIds . AddRange ( fragmentTokens . InputIds ) ;
38- tokenWeights . AddRange ( Enumerable . Repeat ( fragment . Weight , fragmentTokens . InputIds . Length ) ) ;
37+ tokenIds . AddRange ( fragmentTokens . InputIds . Span ) ;
38+ tokenWeights . AddRange ( Enumerable . Repeat ( fragment . Weight , fragmentTokens . InputIds . Span . Length ) ) ;
3939 }
4040
4141 tokenIds . Add ( tokenizer . EOS ) ; //eos
@@ -64,9 +64,10 @@ public static async Task<TextEncoderResult> EncodePromptAsync(CLIPTextModel text
6464 }
6565 else
6666 {
67- var bos = inputTokens . InputIds [ 0 ] ;
68- var eos = inputTokens . InputIds [ ^ 1 ] ;
69- var tokenIds = inputTokens . InputIds [ 1 ..^ 1 ] . Pad ( textEncoder . PadTokenId , minimumLength ) ;
67+
68+ var bos = inputTokens . InputIds . Span [ 0 ] ;
69+ var eos = inputTokens . InputIds . Span [ ^ 1 ] ;
70+ var tokenIds = inputTokens . InputIds . Span [ 1 ..^ 1 ] . Pad ( textEncoder . PadTokenId , minimumLength ) ;
7071
7172 // Create batches, 75 tokens + EOS & BOS
7273 var chunkSize = textEncoder . SequenceLength - 2 ;
@@ -149,7 +150,7 @@ public static void ApplyPromptWeights(TokenizerResult tokenizerOutput, TextEncod
149150 var hiddenStates = encoderOutput . HiddenStates ;
150151 var numTokens = hiddenStates . Dimensions [ 1 ] ;
151152 var embedDim = hiddenStates . Dimensions [ 2 ] ;
152- var weights = tokenizerOutput . Weights . Pad ( 1 , numTokens ) ;
153+ var weights = tokenizerOutput . Weights . Span . Pad ( 1 , numTokens ) ;
153154 if ( weights . All ( x => x == 1 ) )
154155 return ;
155156
0 commit comments