Skip to content

Commit 549fa3e

Browse files
refactor(long_prompt): add helper functions and constants for DRY code
- Add CHUNK_SIZE constant (77) to replace magic numbers
- Add _create_empty_chunk() helper for consistent empty chunk creation
- Add _apply_weights_to_embedding() helper for vectorized weight application
- Remove redundant empty chunk checks (group_tokens_into_chunks guarantees at least one chunk)
- Use _ for the unused clip_chunk_weights in the Flux pooled embedding code

Addresses efficiency improvements from PR review feedback.
1 parent 5084074 commit 549fa3e

File tree

1 file changed

+64
-105
lines changed

1 file changed

+64
-105
lines changed

src/oneiro/pipelines/long_prompt.py

Lines changed: 64 additions & 105 deletions

Original file line number | Diff line number | Diff line change
@@ -21,10 +21,45 @@
2121
BOS_TOKEN_ID = 49406 # Start of text
2222
EOS_TOKEN_ID = 49407 # End of text
2323
MAX_TOKENS_PER_CHUNK = 75 # 77 - 2 (BOS + EOS)
24+
CHUNK_SIZE = 77 # BOS + 75 content tokens + EOS
2425

2526
# T5-XXL embedding dimension
2627
T5_EMBEDDING_DIM = 4096
2728

29+
30+
def _create_empty_chunk() -> tuple[list[int], list[float]]:
31+
"""Create an empty 77-token chunk with BOS/EOS structure.
32+
33+
Returns:
34+
Tuple of (chunk, weights) where chunk is [BOS] + 76*[EOS]
35+
and weights is all 1.0s.
36+
"""
37+
chunk = [BOS_TOKEN_ID] + [EOS_TOKEN_ID] * (MAX_TOKENS_PER_CHUNK + 1)
38+
weights = [1.0] * CHUNK_SIZE
39+
return chunk, weights
40+
41+
42+
def _apply_weights_to_embedding(
43+
embedding: torch.Tensor,
44+
weights: torch.Tensor,
45+
) -> torch.Tensor:
46+
"""Apply per-token weights to embedding tensor in-place.
47+
48+
Uses vectorized multiplication to efficiently scale each token's
49+
embedding vector by its corresponding weight.
50+
51+
Args:
52+
embedding: Token embeddings of shape (num_tokens, hidden_dim)
53+
weights: Weight values of shape (num_weights,)
54+
55+
Returns:
56+
The modified embedding tensor (modified in-place)
57+
"""
58+
num_tokens = min(len(weights), embedding.size(0))
59+
embedding[:num_tokens] = embedding[:num_tokens] * weights[:num_tokens].unsqueeze(-1)
60+
return embedding
61+
62+
2863
# Regex for parsing A1111-style attention weights
2964
RE_ATTENTION = re.compile(
3065
r"""
@@ -241,8 +276,7 @@ def group_tokens_into_chunks(
241276
# Handle edge case: if no chunks were produced (e.g., token_ids was empty or only BREAK markers),
242277
# create an empty chunk to ensure encoders have at least one chunk to process
243278
if not new_token_ids:
244-
empty_chunk = [BOS_TOKEN_ID] + [EOS_TOKEN_ID] * (MAX_TOKENS_PER_CHUNK + 1)
245-
empty_weights = [1.0] * 77
279+
empty_chunk, empty_weights = _create_empty_chunk()
246280
new_token_ids.append(empty_chunk)
247281
new_weights.append(empty_weights)
248282

@@ -304,8 +338,7 @@ def pad_chunks_to_same_count(
304338

305339
# Create an empty chunk (all EOS tokens with weight 1.0)
306340
# Chunk structure: [BOS] + 75 EOS tokens + [EOS] = 77 tokens
307-
empty_chunk = [BOS_TOKEN_ID] + [EOS_TOKEN_ID] * (MAX_TOKENS_PER_CHUNK + 1)
308-
empty_weights = [1.0] * 77
341+
empty_chunk, empty_weights = _create_empty_chunk()
309342

310343
if len_a > len_b:
311344
for _ in range(len_a - len_b):
@@ -355,17 +388,8 @@ def get_weighted_text_embeddings_sd15(
355388
neg_tokens, neg_weights, pad_last_block=pad_last_block
356389
)
357390

358-
# Handle edge case where chunking produces empty lists (e.g., prompt of only BREAK keywords)
359-
# Create an empty chunk to ensure encoders have at least one chunk to process
360-
empty_chunk = [BOS_TOKEN_ID] + [EOS_TOKEN_ID] * (MAX_TOKENS_PER_CHUNK + 1)
361-
empty_weights = [1.0] * 77
362-
363-
if not prompt_chunks:
364-
prompt_chunks = [empty_chunk]
365-
prompt_chunk_weights = [empty_weights]
366-
if not neg_chunks:
367-
neg_chunks = [empty_chunk]
368-
neg_chunk_weights = [empty_weights]
391+
# Note: group_tokens_into_chunks guarantees at least one chunk,
392+
# so no need to handle empty chunk lists here
369393

370394
# Ensure same number of chunks (in case of different BREAK marker counts)
371395
prompt_chunks, prompt_chunk_weights, neg_chunks, neg_chunk_weights = pad_chunks_to_same_count(
@@ -391,11 +415,8 @@ def get_weighted_text_embeddings_sd15(
391415
with torch.no_grad():
392416
token_embedding = pipe.text_encoder(token_tensor)[0].squeeze(0)
393417

394-
# Apply weights using vectorized multiplication
395-
num_tokens = min(len(weight_tensor), token_embedding.size(0))
396-
token_embedding[:num_tokens] = token_embedding[:num_tokens] * weight_tensor[
397-
:num_tokens
398-
].unsqueeze(-1)
418+
# Apply weights
419+
_apply_weights_to_embedding(token_embedding, weight_tensor)
399420

400421
embeds.append(token_embedding.unsqueeze(0))
401422

@@ -406,11 +427,8 @@ def get_weighted_text_embeddings_sd15(
406427
with torch.no_grad():
407428
neg_token_embedding = pipe.text_encoder(neg_token_tensor)[0].squeeze(0)
408429

409-
# Apply weights using vectorized multiplication
410-
num_neg_tokens = min(len(neg_weight_tensor), neg_token_embedding.size(0))
411-
neg_token_embedding[:num_neg_tokens] = neg_token_embedding[
412-
:num_neg_tokens
413-
] * neg_weight_tensor[:num_neg_tokens].unsqueeze(-1)
430+
# Apply weights
431+
_apply_weights_to_embedding(neg_token_embedding, neg_weight_tensor)
414432

415433
neg_embeds.append(neg_token_embedding.unsqueeze(0))
416434

@@ -516,23 +534,8 @@ def get_weighted_text_embeddings_sdxl(
516534
neg_tokens_2, neg_weights_2, pad_last_block=pad_last_block
517535
)
518536

519-
# Handle edge case where chunking produces empty lists (e.g., prompt of only BREAK keywords)
520-
# Create an empty chunk to ensure encoders have at least one chunk to process
521-
empty_chunk = [BOS_TOKEN_ID] + [EOS_TOKEN_ID] * (MAX_TOKENS_PER_CHUNK + 1)
522-
empty_weights = [1.0] * 77
523-
524-
if not prompt_chunks_1:
525-
prompt_chunks_1 = [empty_chunk]
526-
prompt_chunk_weights_1 = [empty_weights]
527-
if not neg_chunks_1:
528-
neg_chunks_1 = [empty_chunk]
529-
neg_chunk_weights_1 = [empty_weights]
530-
if not prompt_chunks_2:
531-
prompt_chunks_2 = [empty_chunk]
532-
prompt_chunk_weights_2 = [empty_weights]
533-
if not neg_chunks_2:
534-
neg_chunks_2 = [empty_chunk]
535-
neg_chunk_weights_2 = [empty_weights]
537+
# Note: group_tokens_into_chunks guarantees at least one chunk,
538+
# so no need to handle empty chunk lists here
536539

537540
# Ensure same number of chunks for each encoder (in case of different BREAK marker counts)
538541
prompt_chunks_1, prompt_chunk_weights_1, neg_chunks_1, neg_chunk_weights_1 = (
@@ -579,11 +582,8 @@ def get_weighted_text_embeddings_sdxl(
579582
[prompt_embeds_1_hidden, prompt_embeds_2_hidden], dim=-1
580583
).squeeze(0)
581584

582-
# Apply weights using vectorized multiplication
583-
num_tokens = min(len(weight_tensor), token_embedding.size(0))
584-
token_embedding[:num_tokens] = token_embedding[:num_tokens] * weight_tensor[
585-
:num_tokens
586-
].unsqueeze(-1)
585+
# Apply weights
586+
_apply_weights_to_embedding(token_embedding, weight_tensor)
587587

588588
embeds.append(token_embedding.unsqueeze(0))
589589

@@ -609,11 +609,8 @@ def get_weighted_text_embeddings_sdxl(
609609
[neg_prompt_embeds_1_hidden, neg_prompt_embeds_2_hidden], dim=-1
610610
).squeeze(0)
611611

612-
# Apply weights using vectorized multiplication
613-
num_neg_tokens = min(len(neg_weight_tensor), neg_token_embedding.size(0))
614-
neg_token_embedding[:num_neg_tokens] = neg_token_embedding[
615-
:num_neg_tokens
616-
] * neg_weight_tensor[:num_neg_tokens].unsqueeze(-1)
612+
# Apply weights
613+
_apply_weights_to_embedding(neg_token_embedding, neg_weight_tensor)
617614

618615
neg_embeds.append(neg_token_embedding.unsqueeze(0))
619616

@@ -673,17 +670,9 @@ def get_weighted_text_embeddings_flux(
673670

674671
# Get CLIP tokens for pooled embeddings (uses chunking for long prompts)
675672
clip_tokens, clip_weights = get_tokens_and_weights(pipe.tokenizer, effective_prompt)
676-
clip_chunks, clip_chunk_weights = group_tokens_into_chunks(
677-
clip_tokens, clip_weights, pad_last_block=True
678-
)
679-
680-
# Handle edge case: if no chunks were produced (e.g., prompt was only BREAK keywords),
681-
# create an empty chunk to ensure we have at least one embedding
682-
if not clip_chunks:
683-
empty_chunk = [BOS_TOKEN_ID] + [EOS_TOKEN_ID] * (MAX_TOKENS_PER_CHUNK + 1)
684-
clip_chunks = [empty_chunk]
685-
# Note: clip_chunk_weights is not used for pooled embeddings but we maintain
686-
# the variable for consistency with the chunking API
673+
# Note: group_tokens_into_chunks guarantees at least one chunk, and
674+
# clip_chunk_weights is not used for pooled embeddings
675+
clip_chunks, _ = group_tokens_into_chunks(clip_tokens, clip_weights, pad_last_block=True)
687676

688677
# Get T5 tokens for main embeddings (no chunking needed, T5 handles long sequences)
689678
t5_tokens, t5_weights = get_t5_tokens_and_weights(pipe.tokenizer_2, effective_prompt)
@@ -706,12 +695,9 @@ def get_weighted_text_embeddings_flux(
706695
t5_output = pipe.text_encoder_2(t5_token_tensor)
707696
t5_embeds = t5_output[0].squeeze(0)
708697

709-
# Apply weights to T5 embeddings using vectorized multiplication
698+
# Apply weights to T5 embeddings
710699
t5_weight_tensor = torch.tensor(t5_weights, dtype=t5_embeds.dtype, device=device)
711-
num_t5_tokens = min(len(t5_weight_tensor), t5_embeds.size(0))
712-
t5_embeds[:num_t5_tokens] = t5_embeds[:num_t5_tokens] * t5_weight_tensor[
713-
:num_t5_tokens
714-
].unsqueeze(-1)
700+
_apply_weights_to_embedding(t5_embeds, t5_weight_tensor)
715701

716702
prompt_embeds = t5_embeds.unsqueeze(0)
717703
prompt_embeds = prompt_embeds.to(dtype=pipe.text_encoder_2.dtype, device=device)
@@ -781,23 +767,8 @@ def get_weighted_text_embeddings_sd3(
781767
neg_tokens_2, neg_weights_2, pad_last_block=pad_last_block
782768
)
783769

784-
# Handle edge case where chunking produces empty lists (e.g., prompt of only BREAK keywords)
785-
# Create an empty chunk to ensure encoders have at least one chunk to process
786-
empty_chunk = [BOS_TOKEN_ID] + [EOS_TOKEN_ID] * (MAX_TOKENS_PER_CHUNK + 1)
787-
empty_weights = [1.0] * 77
788-
789-
if not prompt_chunks_1:
790-
prompt_chunks_1 = [empty_chunk]
791-
prompt_chunk_weights_1 = [empty_weights]
792-
if not neg_chunks_1:
793-
neg_chunks_1 = [empty_chunk]
794-
neg_chunk_weights_1 = [empty_weights]
795-
if not prompt_chunks_2:
796-
prompt_chunks_2 = [empty_chunk]
797-
prompt_chunk_weights_2 = [empty_weights]
798-
if not neg_chunks_2:
799-
neg_chunks_2 = [empty_chunk]
800-
neg_chunk_weights_2 = [empty_weights]
770+
# Note: group_tokens_into_chunks guarantees at least one chunk,
771+
# so no need to handle empty chunk lists here
801772

802773
# Ensure same number of chunks for each encoder (in case of different BREAK marker counts)
803774
prompt_chunks_1, prompt_chunk_weights_1, neg_chunks_1, neg_chunk_weights_1 = (
@@ -850,11 +821,8 @@ def get_weighted_text_embeddings_sd3(
850821
[prompt_embeds_1_hidden, prompt_embeds_2_hidden], dim=-1
851822
).squeeze(0)
852823

853-
# Apply weights using vectorized multiplication
854-
num_tokens = min(len(weight_tensor), token_embedding.size(0))
855-
token_embedding[:num_tokens] = token_embedding[:num_tokens] * weight_tensor[
856-
:num_tokens
857-
].unsqueeze(-1)
824+
# Apply weights
825+
_apply_weights_to_embedding(token_embedding, weight_tensor)
858826

859827
embeds.append(token_embedding.unsqueeze(0))
860828

@@ -880,11 +848,8 @@ def get_weighted_text_embeddings_sd3(
880848
[neg_prompt_embeds_1_hidden, neg_prompt_embeds_2_hidden], dim=-1
881849
).squeeze(0)
882850

883-
# Apply weights using vectorized multiplication
884-
num_neg_tokens = min(len(neg_weight_tensor), neg_token_embedding.size(0))
885-
neg_token_embedding[:num_neg_tokens] = neg_token_embedding[
886-
:num_neg_tokens
887-
] * neg_weight_tensor[:num_neg_tokens].unsqueeze(-1)
851+
# Apply weights
852+
_apply_weights_to_embedding(neg_token_embedding, neg_weight_tensor)
888853

889854
neg_embeds.append(neg_token_embedding.unsqueeze(0))
890855

@@ -903,24 +868,18 @@ def get_weighted_text_embeddings_sd3(
903868
t5_token_tensor = torch.tensor([prompt_tokens_3], dtype=torch.long, device=device)
904869
t5_embeds = pipe.text_encoder_3(t5_token_tensor)[0].squeeze(0)
905870

906-
# Apply weights using vectorized multiplication
871+
# Apply weights
907872
t5_weight_tensor = torch.tensor(prompt_weights_3, dtype=t5_embeds.dtype, device=device)
908-
num_t5_tokens = min(len(t5_weight_tensor), t5_embeds.size(0))
909-
t5_embeds[:num_t5_tokens] = t5_embeds[:num_t5_tokens] * t5_weight_tensor[
910-
:num_t5_tokens
911-
].unsqueeze(-1)
873+
_apply_weights_to_embedding(t5_embeds, t5_weight_tensor)
912874
t5_embeds = t5_embeds.unsqueeze(0)
913875

914876
# Negative T5
915877
neg_t5_token_tensor = torch.tensor([neg_tokens_3], dtype=torch.long, device=device)
916878
neg_t5_embeds = pipe.text_encoder_3(neg_t5_token_tensor)[0].squeeze(0)
917879

918-
# Apply weights using vectorized multiplication
880+
# Apply weights
919881
neg_t5_weight_tensor = torch.tensor(neg_weights_3, dtype=neg_t5_embeds.dtype, device=device)
920-
num_neg_t5_tokens = min(len(neg_t5_weight_tensor), neg_t5_embeds.size(0))
921-
neg_t5_embeds[:num_neg_t5_tokens] = neg_t5_embeds[
922-
:num_neg_t5_tokens
923-
] * neg_t5_weight_tensor[:num_neg_t5_tokens].unsqueeze(-1)
882+
_apply_weights_to_embedding(neg_t5_embeds, neg_t5_weight_tensor)
924883
neg_t5_embeds = neg_t5_embeds.unsqueeze(0)
925884
else:
926885
# Create zero tensors if T5 not available

0 commit comments

Comments (0)