Skip to content

Commit cb00179

Browse files
authored
Merge pull request #54 from jkoelker/feat/issue-53-long-prompt-support
feat: add long prompt support (>77 tokens) for CLIP-based pipelines
2 parents 587071f + 549fa3e commit cb00179

File tree

4 files changed

+2386
-11
lines changed

4 files changed

+2386
-11
lines changed

src/oneiro/pipelines/civitai_checkpoint.py

Lines changed: 134 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414

1515
from oneiro.pipelines.base import BasePipeline, GenerationResult
1616
from oneiro.pipelines.embedding import EmbeddingLoaderMixin, parse_embeddings_from_config
17+
from oneiro.pipelines.long_prompt import (
18+
get_weighted_text_embeddings_flux,
19+
get_weighted_text_embeddings_sd3,
20+
get_weighted_text_embeddings_sd15,
21+
get_weighted_text_embeddings_sdxl,
22+
)
1723
from oneiro.pipelines.lora import LoraLoaderMixin, parse_loras_from_model_config
1824

1925
if TYPE_CHECKING:
@@ -777,15 +783,20 @@ def generate(
777783

778784
# Build generation kwargs
779785
gen_kwargs: dict[str, Any] = {
780-
"prompt": prompt,
781786
"num_inference_steps": steps,
782787
"guidance_scale": guidance_scale,
783788
"generator": generator,
784789
}
785790

786-
# Add negative prompt only for pipelines that support it
787-
if self._pipeline_config.supports_negative_prompt and negative_prompt:
788-
gen_kwargs["negative_prompt"] = negative_prompt
791+
# Use embedding-based prompt handling for pipelines that support it
792+
# (SD 1.x, SD 2.x, SDXL, Flux, SD3) - enables weight syntax like (word:1.5)
793+
if self._supports_prompt_embeddings():
794+
gen_kwargs.update(self._encode_prompts_to_embeddings(prompt, negative_prompt))
795+
else:
796+
# Fallback for unsupported pipelines
797+
gen_kwargs["prompt"] = prompt
798+
if self._pipeline_config.supports_negative_prompt and negative_prompt:
799+
gen_kwargs["negative_prompt"] = negative_prompt
789800

790801
if init_image:
791802
print(f"CivitAI img2img: '{prompt[:50]}...' seed={actual_seed} strength={strength}")
@@ -822,3 +833,122 @@ def pipeline_config(self) -> PipelineConfig | None:
822833
def detected_base_model(self) -> str | None:
823834
"""Get the detected/configured base model."""
824835
return self._base_model
836+
837+
def _supports_prompt_embeddings(self) -> bool:
838+
"""Check if this pipeline supports embedding-based prompt handling.
839+
840+
Returns True for pipelines that support pre-computed embeddings with
841+
weight handling (CLIP-based: SD 1.x, SD 2.x, SDXL; flow-based: Flux;
842+
and MMDiT-based: SD3).
843+
844+
Returns:
845+
True if the pipeline supports prompt embeddings
846+
"""
847+
if self.pipe is None or self._pipeline_config is None:
848+
return False
849+
850+
# Pipelines that support embedding-based prompt handling
851+
pipeline_class = self._pipeline_config.pipeline_class
852+
supported_pipelines = {
853+
"StableDiffusionPipeline",
854+
"StableDiffusionXLPipeline",
855+
"FluxPipeline",
856+
"StableDiffusion3Pipeline",
857+
}
858+
859+
return pipeline_class in supported_pipelines
860+
861+
def _encode_prompts_to_embeddings(
    self,
    prompt: str,
    negative_prompt: str | None,
) -> dict[str, Any]:
    """Encode text prompts into pre-computed embedding kwargs.

    Supports A1111-style weight syntax like ``(word:1.5)`` and ``[word]``,
    prompts longer than CLIP's 77-token limit via chunking, and the
    ``BREAK`` keyword for forcing chunk boundaries. Each supported
    architecture (SD 1.x/2.x, SDXL, Flux, SD3) gets its matching encoder.

    Args:
        prompt: The positive prompt.
        negative_prompt: The negative prompt (may be None).

    Returns:
        Dict of embedding kwargs to pass to the pipeline (empty when no
        pipeline/config is loaded).
    """
    if self.pipe is None or self._pipeline_config is None:
        return {}

    neg = negative_prompt or ""
    cls_name = self._pipeline_config.pipeline_class
    allow_negative = self._pipeline_config.supports_negative_prompt
    kwargs: dict[str, Any] = {}

    if cls_name == "FluxPipeline":
        # Flux uses T5 for main embeddings + CLIP for pooled.
        # Note: Flux does not support negative prompts.
        embeds, pooled = get_weighted_text_embeddings_flux(
            self.pipe,
            prompt=prompt,
        )
        kwargs["prompt_embeds"] = embeds
        kwargs["pooled_prompt_embeds"] = pooled
        return kwargs

    if cls_name in ("StableDiffusion3Pipeline", "StableDiffusionXLPipeline"):
        # SD3 (dual CLIP + T5) and SDXL (dual text encoders) share the
        # same four-tuple contract; only the encoder function differs.
        encode = (
            get_weighted_text_embeddings_sd3
            if cls_name == "StableDiffusion3Pipeline"
            else get_weighted_text_embeddings_sdxl
        )
        embeds, neg_embeds, pooled, neg_pooled = encode(
            self.pipe,
            prompt=prompt,
            negative_prompt=neg,
        )
        kwargs["prompt_embeds"] = embeds
        kwargs["pooled_prompt_embeds"] = pooled
        if allow_negative:
            kwargs["negative_prompt_embeds"] = neg_embeds
            kwargs["negative_pooled_prompt_embeds"] = neg_pooled
        return kwargs

    # SD 1.x / 2.x use a single text encoder.
    embeds, neg_embeds = get_weighted_text_embeddings_sd15(
        self.pipe,
        prompt=prompt,
        negative_prompt=neg,
    )
    kwargs["prompt_embeds"] = embeds
    if allow_negative:
        kwargs["negative_prompt_embeds"] = neg_embeds
    return kwargs

0 commit comments

Comments
 (0)