2121BOS_TOKEN_ID = 49406 # Start of text
2222EOS_TOKEN_ID = 49407 # End of text
2323MAX_TOKENS_PER_CHUNK = 75 # 77 - 2 (BOS + EOS)
24+ CHUNK_SIZE = 77 # BOS + 75 content tokens + EOS
2425
2526# T5-XXL embedding dimension
2627T5_EMBEDDING_DIM = 4096
2728
29+
30+ def _create_empty_chunk () -> tuple [list [int ], list [float ]]:
31+ """Create an empty 77-token chunk with BOS/EOS structure.
32+
33+ Returns:
34+ Tuple of (chunk, weights) where chunk is [BOS] + 76*[EOS]
35+ and weights is all 1.0s.
36+ """
37+ chunk = [BOS_TOKEN_ID ] + [EOS_TOKEN_ID ] * (MAX_TOKENS_PER_CHUNK + 1 )
38+ weights = [1.0 ] * CHUNK_SIZE
39+ return chunk , weights
40+
41+
42+ def _apply_weights_to_embedding (
43+ embedding : torch .Tensor ,
44+ weights : torch .Tensor ,
45+ ) -> torch .Tensor :
46+ """Apply per-token weights to embedding tensor in-place.
47+
48+ Uses vectorized multiplication to efficiently scale each token's
49+ embedding vector by its corresponding weight.
50+
51+ Args:
52+ embedding: Token embeddings of shape (num_tokens, hidden_dim)
53+ weights: Weight values of shape (num_weights,)
54+
55+ Returns:
56+ The modified embedding tensor (modified in-place)
57+ """
58+ num_tokens = min (len (weights ), embedding .size (0 ))
59+ embedding [:num_tokens ] = embedding [:num_tokens ] * weights [:num_tokens ].unsqueeze (- 1 )
60+ return embedding
61+
62+
2863# Regex for parsing A1111-style attention weights
2964RE_ATTENTION = re .compile (
3065 r"""
@@ -241,8 +276,7 @@ def group_tokens_into_chunks(
241276 # Handle edge case: if no chunks were produced (e.g., token_ids was empty or only BREAK markers),
242277 # create an empty chunk to ensure encoders have at least one chunk to process
243278 if not new_token_ids :
244- empty_chunk = [BOS_TOKEN_ID ] + [EOS_TOKEN_ID ] * (MAX_TOKENS_PER_CHUNK + 1 )
245- empty_weights = [1.0 ] * 77
279+ empty_chunk , empty_weights = _create_empty_chunk ()
246280 new_token_ids .append (empty_chunk )
247281 new_weights .append (empty_weights )
248282
@@ -304,8 +338,7 @@ def pad_chunks_to_same_count(
304338
305339 # Create an empty chunk (all EOS tokens with weight 1.0)
306340 # Chunk structure: [BOS] + 75 EOS tokens + [EOS] = 77 tokens
307- empty_chunk = [BOS_TOKEN_ID ] + [EOS_TOKEN_ID ] * (MAX_TOKENS_PER_CHUNK + 1 )
308- empty_weights = [1.0 ] * 77
341+ empty_chunk , empty_weights = _create_empty_chunk ()
309342
310343 if len_a > len_b :
311344 for _ in range (len_a - len_b ):
@@ -355,17 +388,8 @@ def get_weighted_text_embeddings_sd15(
355388 neg_tokens , neg_weights , pad_last_block = pad_last_block
356389 )
357390
358- # Handle edge case where chunking produces empty lists (e.g., prompt of only BREAK keywords)
359- # Create an empty chunk to ensure encoders have at least one chunk to process
360- empty_chunk = [BOS_TOKEN_ID ] + [EOS_TOKEN_ID ] * (MAX_TOKENS_PER_CHUNK + 1 )
361- empty_weights = [1.0 ] * 77
362-
363- if not prompt_chunks :
364- prompt_chunks = [empty_chunk ]
365- prompt_chunk_weights = [empty_weights ]
366- if not neg_chunks :
367- neg_chunks = [empty_chunk ]
368- neg_chunk_weights = [empty_weights ]
391+ # Note: group_tokens_into_chunks guarantees at least one chunk,
392+ # so no need to handle empty chunk lists here
369393
370394 # Ensure same number of chunks (in case of different BREAK marker counts)
371395 prompt_chunks , prompt_chunk_weights , neg_chunks , neg_chunk_weights = pad_chunks_to_same_count (
@@ -391,11 +415,8 @@ def get_weighted_text_embeddings_sd15(
391415 with torch .no_grad ():
392416 token_embedding = pipe .text_encoder (token_tensor )[0 ].squeeze (0 )
393417
394- # Apply weights using vectorized multiplication
395- num_tokens = min (len (weight_tensor ), token_embedding .size (0 ))
396- token_embedding [:num_tokens ] = token_embedding [:num_tokens ] * weight_tensor [
397- :num_tokens
398- ].unsqueeze (- 1 )
418+ # Apply weights
419+ _apply_weights_to_embedding (token_embedding , weight_tensor )
399420
400421 embeds .append (token_embedding .unsqueeze (0 ))
401422
@@ -406,11 +427,8 @@ def get_weighted_text_embeddings_sd15(
406427 with torch .no_grad ():
407428 neg_token_embedding = pipe .text_encoder (neg_token_tensor )[0 ].squeeze (0 )
408429
409- # Apply weights using vectorized multiplication
410- num_neg_tokens = min (len (neg_weight_tensor ), neg_token_embedding .size (0 ))
411- neg_token_embedding [:num_neg_tokens ] = neg_token_embedding [
412- :num_neg_tokens
413- ] * neg_weight_tensor [:num_neg_tokens ].unsqueeze (- 1 )
430+ # Apply weights
431+ _apply_weights_to_embedding (neg_token_embedding , neg_weight_tensor )
414432
415433 neg_embeds .append (neg_token_embedding .unsqueeze (0 ))
416434
@@ -516,23 +534,8 @@ def get_weighted_text_embeddings_sdxl(
516534 neg_tokens_2 , neg_weights_2 , pad_last_block = pad_last_block
517535 )
518536
519- # Handle edge case where chunking produces empty lists (e.g., prompt of only BREAK keywords)
520- # Create an empty chunk to ensure encoders have at least one chunk to process
521- empty_chunk = [BOS_TOKEN_ID ] + [EOS_TOKEN_ID ] * (MAX_TOKENS_PER_CHUNK + 1 )
522- empty_weights = [1.0 ] * 77
523-
524- if not prompt_chunks_1 :
525- prompt_chunks_1 = [empty_chunk ]
526- prompt_chunk_weights_1 = [empty_weights ]
527- if not neg_chunks_1 :
528- neg_chunks_1 = [empty_chunk ]
529- neg_chunk_weights_1 = [empty_weights ]
530- if not prompt_chunks_2 :
531- prompt_chunks_2 = [empty_chunk ]
532- prompt_chunk_weights_2 = [empty_weights ]
533- if not neg_chunks_2 :
534- neg_chunks_2 = [empty_chunk ]
535- neg_chunk_weights_2 = [empty_weights ]
537+ # Note: group_tokens_into_chunks guarantees at least one chunk,
538+ # so no need to handle empty chunk lists here
536539
537540 # Ensure same number of chunks for each encoder (in case of different BREAK marker counts)
538541 prompt_chunks_1 , prompt_chunk_weights_1 , neg_chunks_1 , neg_chunk_weights_1 = (
@@ -579,11 +582,8 @@ def get_weighted_text_embeddings_sdxl(
579582 [prompt_embeds_1_hidden , prompt_embeds_2_hidden ], dim = - 1
580583 ).squeeze (0 )
581584
582- # Apply weights using vectorized multiplication
583- num_tokens = min (len (weight_tensor ), token_embedding .size (0 ))
584- token_embedding [:num_tokens ] = token_embedding [:num_tokens ] * weight_tensor [
585- :num_tokens
586- ].unsqueeze (- 1 )
585+ # Apply weights
586+ _apply_weights_to_embedding (token_embedding , weight_tensor )
587587
588588 embeds .append (token_embedding .unsqueeze (0 ))
589589
@@ -609,11 +609,8 @@ def get_weighted_text_embeddings_sdxl(
609609 [neg_prompt_embeds_1_hidden , neg_prompt_embeds_2_hidden ], dim = - 1
610610 ).squeeze (0 )
611611
612- # Apply weights using vectorized multiplication
613- num_neg_tokens = min (len (neg_weight_tensor ), neg_token_embedding .size (0 ))
614- neg_token_embedding [:num_neg_tokens ] = neg_token_embedding [
615- :num_neg_tokens
616- ] * neg_weight_tensor [:num_neg_tokens ].unsqueeze (- 1 )
612+ # Apply weights
613+ _apply_weights_to_embedding (neg_token_embedding , neg_weight_tensor )
617614
618615 neg_embeds .append (neg_token_embedding .unsqueeze (0 ))
619616
@@ -673,17 +670,9 @@ def get_weighted_text_embeddings_flux(
673670
674671 # Get CLIP tokens for pooled embeddings (uses chunking for long prompts)
675672 clip_tokens , clip_weights = get_tokens_and_weights (pipe .tokenizer , effective_prompt )
676- clip_chunks , clip_chunk_weights = group_tokens_into_chunks (
677- clip_tokens , clip_weights , pad_last_block = True
678- )
679-
680- # Handle edge case: if no chunks were produced (e.g., prompt was only BREAK keywords),
681- # create an empty chunk to ensure we have at least one embedding
682- if not clip_chunks :
683- empty_chunk = [BOS_TOKEN_ID ] + [EOS_TOKEN_ID ] * (MAX_TOKENS_PER_CHUNK + 1 )
684- clip_chunks = [empty_chunk ]
685- # Note: clip_chunk_weights is not used for pooled embeddings but we maintain
686- # the variable for consistency with the chunking API
673+ # Note: group_tokens_into_chunks guarantees at least one chunk, and
674+ # clip_chunk_weights is not used for pooled embeddings
675+ clip_chunks , _ = group_tokens_into_chunks (clip_tokens , clip_weights , pad_last_block = True )
687676
688677 # Get T5 tokens for main embeddings (no chunking needed, T5 handles long sequences)
689678 t5_tokens , t5_weights = get_t5_tokens_and_weights (pipe .tokenizer_2 , effective_prompt )
@@ -706,12 +695,9 @@ def get_weighted_text_embeddings_flux(
706695 t5_output = pipe .text_encoder_2 (t5_token_tensor )
707696 t5_embeds = t5_output [0 ].squeeze (0 )
708697
709- # Apply weights to T5 embeddings using vectorized multiplication
698+ # Apply weights to T5 embeddings
710699 t5_weight_tensor = torch .tensor (t5_weights , dtype = t5_embeds .dtype , device = device )
711- num_t5_tokens = min (len (t5_weight_tensor ), t5_embeds .size (0 ))
712- t5_embeds [:num_t5_tokens ] = t5_embeds [:num_t5_tokens ] * t5_weight_tensor [
713- :num_t5_tokens
714- ].unsqueeze (- 1 )
700+ _apply_weights_to_embedding (t5_embeds , t5_weight_tensor )
715701
716702 prompt_embeds = t5_embeds .unsqueeze (0 )
717703 prompt_embeds = prompt_embeds .to (dtype = pipe .text_encoder_2 .dtype , device = device )
@@ -781,23 +767,8 @@ def get_weighted_text_embeddings_sd3(
781767 neg_tokens_2 , neg_weights_2 , pad_last_block = pad_last_block
782768 )
783769
784- # Handle edge case where chunking produces empty lists (e.g., prompt of only BREAK keywords)
785- # Create an empty chunk to ensure encoders have at least one chunk to process
786- empty_chunk = [BOS_TOKEN_ID ] + [EOS_TOKEN_ID ] * (MAX_TOKENS_PER_CHUNK + 1 )
787- empty_weights = [1.0 ] * 77
788-
789- if not prompt_chunks_1 :
790- prompt_chunks_1 = [empty_chunk ]
791- prompt_chunk_weights_1 = [empty_weights ]
792- if not neg_chunks_1 :
793- neg_chunks_1 = [empty_chunk ]
794- neg_chunk_weights_1 = [empty_weights ]
795- if not prompt_chunks_2 :
796- prompt_chunks_2 = [empty_chunk ]
797- prompt_chunk_weights_2 = [empty_weights ]
798- if not neg_chunks_2 :
799- neg_chunks_2 = [empty_chunk ]
800- neg_chunk_weights_2 = [empty_weights ]
770+ # Note: group_tokens_into_chunks guarantees at least one chunk,
771+ # so no need to handle empty chunk lists here
801772
802773 # Ensure same number of chunks for each encoder (in case of different BREAK marker counts)
803774 prompt_chunks_1 , prompt_chunk_weights_1 , neg_chunks_1 , neg_chunk_weights_1 = (
@@ -850,11 +821,8 @@ def get_weighted_text_embeddings_sd3(
850821 [prompt_embeds_1_hidden , prompt_embeds_2_hidden ], dim = - 1
851822 ).squeeze (0 )
852823
853- # Apply weights using vectorized multiplication
854- num_tokens = min (len (weight_tensor ), token_embedding .size (0 ))
855- token_embedding [:num_tokens ] = token_embedding [:num_tokens ] * weight_tensor [
856- :num_tokens
857- ].unsqueeze (- 1 )
824+ # Apply weights
825+ _apply_weights_to_embedding (token_embedding , weight_tensor )
858826
859827 embeds .append (token_embedding .unsqueeze (0 ))
860828
@@ -880,11 +848,8 @@ def get_weighted_text_embeddings_sd3(
880848 [neg_prompt_embeds_1_hidden , neg_prompt_embeds_2_hidden ], dim = - 1
881849 ).squeeze (0 )
882850
883- # Apply weights using vectorized multiplication
884- num_neg_tokens = min (len (neg_weight_tensor ), neg_token_embedding .size (0 ))
885- neg_token_embedding [:num_neg_tokens ] = neg_token_embedding [
886- :num_neg_tokens
887- ] * neg_weight_tensor [:num_neg_tokens ].unsqueeze (- 1 )
851+ # Apply weights
852+ _apply_weights_to_embedding (neg_token_embedding , neg_weight_tensor )
888853
889854 neg_embeds .append (neg_token_embedding .unsqueeze (0 ))
890855
@@ -903,24 +868,18 @@ def get_weighted_text_embeddings_sd3(
903868 t5_token_tensor = torch .tensor ([prompt_tokens_3 ], dtype = torch .long , device = device )
904869 t5_embeds = pipe .text_encoder_3 (t5_token_tensor )[0 ].squeeze (0 )
905870
906- # Apply weights using vectorized multiplication
871+ # Apply weights
907872 t5_weight_tensor = torch .tensor (prompt_weights_3 , dtype = t5_embeds .dtype , device = device )
908- num_t5_tokens = min (len (t5_weight_tensor ), t5_embeds .size (0 ))
909- t5_embeds [:num_t5_tokens ] = t5_embeds [:num_t5_tokens ] * t5_weight_tensor [
910- :num_t5_tokens
911- ].unsqueeze (- 1 )
873+ _apply_weights_to_embedding (t5_embeds , t5_weight_tensor )
912874 t5_embeds = t5_embeds .unsqueeze (0 )
913875
914876 # Negative T5
915877 neg_t5_token_tensor = torch .tensor ([neg_tokens_3 ], dtype = torch .long , device = device )
916878 neg_t5_embeds = pipe .text_encoder_3 (neg_t5_token_tensor )[0 ].squeeze (0 )
917879
918- # Apply weights using vectorized multiplication
880+ # Apply weights
919881 neg_t5_weight_tensor = torch .tensor (neg_weights_3 , dtype = neg_t5_embeds .dtype , device = device )
920- num_neg_t5_tokens = min (len (neg_t5_weight_tensor ), neg_t5_embeds .size (0 ))
921- neg_t5_embeds [:num_neg_t5_tokens ] = neg_t5_embeds [
922- :num_neg_t5_tokens
923- ] * neg_t5_weight_tensor [:num_neg_t5_tokens ].unsqueeze (- 1 )
882+ _apply_weights_to_embedding (neg_t5_embeds , neg_t5_weight_tensor )
924883 neg_t5_embeds = neg_t5_embeds .unsqueeze (0 )
925884 else :
926885 # Create zero tensors if T5 not available
0 commit comments