@@ -6,8 +6,9 @@
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField
-from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
+from invokeai.app.invocations.model import CLIPField, T5EncoderField
+from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
+from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -22,57 +23,59 @@
     version="1.0.0",
 )
 class FluxTextEncoderInvocation(BaseInvocation):
-    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    clip: CLIPField = InputField(
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
+    t5Encoder: T5EncoderField = InputField(
+        title="T5EncoderField",
+        description=FieldDescriptions.t5Encoder,
+        input=Input.Connection,
     )
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
 
     # TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
     # compatible with other ConditioningOutputs.
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
 
-        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
+        t5_embeddings, clip_embeddings = self._encode_prompt(context)
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )
 
         conditioning_name = context.conditioning.save(conditioning_data)
         return ConditioningOutput.build(conditioning_name)
+
+    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
+        # TODO: Determine the T5 max sequence length based on the model.
+        # if self.model == "flux-schnell":
+        max_seq_len = 256
+        # elif self.model == "flux-dev":
+        #     max_seq_len = 512
+        # else:
+        #     raise ValueError(f"Unknown model: {self.model}")
+
+        # Load CLIP.
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+        # Load T5.
+        t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
+        t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)
 
-    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence length based on the model.
-        if self.model == "flux-schnell":
-            max_seq_len = 256
-        # elif self.model == "flux-dev":
-        #     max_seq_len = 512
-        else:
-            raise ValueError(f"Unknown model: {self.model}")
-
-        # Load the CLIP tokenizer.
-        clip_tokenizer_path = flux_model_dir / "tokenizer"
-        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
-        assert isinstance(clip_tokenizer, CLIPTokenizer)
-
-        # Load the T5 tokenizer.
-        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
-        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
-        assert isinstance(t5_tokenizer, T5TokenizerFast)
-
-        clip_text_encoder_path = flux_model_dir / "text_encoder"
-        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
         with (
-            context.models.load_local_model(
-                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
-            ) as clip_text_encoder,
-            context.models.load_local_model(
-                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
-            ) as t5_text_encoder,
+            clip_text_encoder_info as clip_text_encoder,
+            t5_text_encoder_info as t5_text_encoder,
+            clip_tokenizer_info as clip_tokenizer,
+            t5_tokenizer_info as t5_tokenizer,
         ):
             assert isinstance(clip_text_encoder, CLIPTextModel)
             assert isinstance(t5_text_encoder, T5EncoderModel)
+            assert isinstance(clip_tokenizer, CLIPTokenizer)
+            assert isinstance(t5_tokenizer, T5TokenizerFast)
+
             pipeline = FluxPipeline(
                 scheduler=None,
                 vae=None,
@@ -85,7 +88,7 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
 
             # prompt_embeds: T5 embeddings
             # pooled_prompt_embeds: CLIP embeddings
-            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                 prompt=self.positive_prompt,
                 prompt_2=self.positive_prompt,
                 device=TorchDevice.choose_torch_device(),
@@ -95,41 +98,3 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
         assert isinstance(prompt_embeds, torch.Tensor)
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
         return prompt_embeds, pooled_prompt_embeds
-
-    @staticmethod
-    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
-        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, CLIPTextModel)
-        return model
-
-    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
-        if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
-
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a T5EncoderModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): dtype?
-                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-                assert isinstance(model, T5EncoderModel)
-
-                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
-
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
-        else:
-            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-
-        assert isinstance(model, T5EncoderModel)
-        return model
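For readers outside the InvokeAI codebase, the core of this commit is the switch from loading weights off disk (`load_local_model(...)` with explicit loader callbacks) to requesting handles from the model manager (`context.models.load(...)`) and entering them as context managers, which yield the underlying model for the duration of the `with` block. The snippet below is a minimal sketch of that handle idiom under stated assumptions: `LoadedModel` is a hypothetical stand-in for whatever `context.models.load()` actually returns, not InvokeAI's real implementation.

```python
from contextlib import AbstractContextManager


class LoadedModel(AbstractContextManager):
    """Hypothetical stand-in for the object returned by context.models.load().

    Entering the context is what makes the model safe to use (a real model
    cache would lock it into VRAM/RAM here); exiting releases it.
    """

    def __init__(self, model):
        self._model = model

    def __enter__(self):
        # A real implementation would acquire a cache lock before returning.
        return self._model

    def __exit__(self, exc_type, exc_value, traceback):
        # ...and release the lock here so the cache may offload the model.
        return False


# Usage mirrors the diff: multiple handles entered in one `with` statement.
clip_text_encoder_info = LoadedModel(model="clip_text_encoder_stub")
t5_text_encoder_info = LoadedModel(model="t5_text_encoder_stub")

with (
    clip_text_encoder_info as clip_text_encoder,
    t5_text_encoder_info as t5_text_encoder,
):
    print(clip_text_encoder, t5_text_encoder)
```

One consequence visible in the diff: because the tokenizers now come from the model manager as well, they are entered alongside the encoders in the same `with` statement, so all four models stay resident for exactly the span of the `encode_prompt` call.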