 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from optimum.quanto import qfloat8
 from PIL import Image
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 from transformers.models.auto import AutoModelForTextEncoding

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
+from invokeai.app.invocations.fields import (
+    ConditioningField,
+    FieldDescriptions,
+    Input,
+    InputField,
+    WithBoard,
+    WithMetadata,
+)
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
 from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
-from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo

 TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
@@ -44,7 +50,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     use_8bit: bool = InputField(
         default=False, description="Whether to quantize the transformer model to 8-bit precision."
     )
-    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
+    positive_text_conditioning: ConditioningField = InputField(
+        description=FieldDescriptions.positive_cond, input=Input.Connection
+    )
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
     num_steps: int = InputField(default=4, description="Number of diffusion steps.")
@@ -58,66 +66,17 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

-        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
-        latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+
+        latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
         image = self._run_vae_decoding(context, model_path, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)

-    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence length based on the model.
-        if self.model == "flux-schnell":
-            max_seq_len = 256
-        # elif self.model == "flux-dev":
-        #     max_seq_len = 512
-        else:
-            raise ValueError(f"Unknown model: {self.model}")
-
-        # Load the CLIP tokenizer.
-        clip_tokenizer_path = flux_model_dir / "tokenizer"
-        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
-        assert isinstance(clip_tokenizer, CLIPTokenizer)
-
-        # Load the T5 tokenizer.
-        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
-        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
-        assert isinstance(t5_tokenizer, T5TokenizerFast)
-
-        clip_text_encoder_path = flux_model_dir / "text_encoder"
-        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
-        with (
-            context.models.load_local_model(
-                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
-            ) as clip_text_encoder,
-            context.models.load_local_model(
-                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
-            ) as t5_text_encoder,
-        ):
-            assert isinstance(clip_text_encoder, CLIPTextModel)
-            assert isinstance(t5_text_encoder, T5EncoderModel)
-            pipeline = FluxPipeline(
-                scheduler=None,
-                vae=None,
-                text_encoder=clip_text_encoder,
-                tokenizer=clip_tokenizer,
-                text_encoder_2=t5_text_encoder,
-                tokenizer_2=t5_tokenizer,
-                transformer=None,
-            )
-
-            # prompt_embeds: T5 embeddings
-            # pooled_prompt_embeds: CLIP embeddings
-            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
-                prompt=self.positive_prompt,
-                prompt_2=self.positive_prompt,
-                device=TorchDevice.choose_torch_device(),
-                max_sequence_length=max_seq_len,
-            )
-
-            assert isinstance(prompt_embeds, torch.Tensor)
-            assert isinstance(pooled_prompt_embeds, torch.Tensor)
-            return prompt_embeds, pooled_prompt_embeds
-
     def _run_diffusion(
         self,
         context: InvocationContext,
@@ -199,44 +158,6 @@ def _run_vae_decoding(
         assert isinstance(image, Image.Image)
         return image

-    @staticmethod
-    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
-        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, CLIPTextModel)
-        return model
-
-    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
-        if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
-
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a T5EncoderModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): dtype?
-                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-                assert isinstance(model, T5EncoderModel)
-
-                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
-
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
-        else:
-            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-
-        assert isinstance(model, T5EncoderModel)
-        return model
-
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
         if self.use_8bit:
             model_8bit_path = path / "quantized"
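
Note (reviewer sketch, not part of the diff): this change moves prompt encoding out of the invocation; positive_text_conditioning now arrives as a connection from an upstream node that must have saved a FLUXConditioningInfo. A minimal sketch of that producer side, assuming context.conditioning.save mirrors the context.conditioning.load call used above and that ConditioningFieldData and ConditioningOutput follow InvokeAI's existing conditioning conventions; any name not shown in the diff is an assumption:

    from invokeai.app.invocations.primitives import ConditioningOutput
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
        ConditioningFieldData,
        FLUXConditioningInfo,
    )

    def save_flux_conditioning(context, clip_embeds, t5_embeds):
        # Wrap the CLIP and T5 embedding tensors in the dataclass that
        # invoke() unpacks above (flux_conditioning.clip_embeds / .t5_embeds).
        cond_data = ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeds, t5_embeds=t5_embeds)]
        )
        # Assumed save counterpart to context.conditioning.load(...).
        conditioning_name = context.conditioning.save(cond_data)
        return ConditioningOutput.build(conditioning_name=conditioning_name)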
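Side note on the quantize-and-cache pattern kept in _load_flux_transformer (and removed from the T5 loader above): the Quantized* wrapper classes are driven by optimum.quanto, as the qfloat8 import and the weights=qfloat8 argument show. A minimal sketch of the underlying optimum.quanto flow, using only its public quantize/freeze API; the helper name is hypothetical:

    from optimum.quanto import freeze, qfloat8, quantize
    from transformers import T5EncoderModel

    def quantize_t5_to_qfloat8(path: str) -> T5EncoderModel:
        # Load full-precision weights, quantize them to qfloat8 in place,
        # then freeze so the quantized tensors replace the originals.
        model = T5EncoderModel.from_pretrained(path, local_files_only=True)
        quantize(model, weights=qfloat8)
        freeze(model)
        return model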