@@ -191,6 +191,7 @@ def __init__(
191191        transformer : FluxTransformer2DModel ,
192192        image_encoder : CLIPVisionModelWithProjection  =  None ,
193193        feature_extractor : CLIPImageProcessor  =  None ,
194+         variant : str  =  "flux" ,
194195    ):
195196        super ().__init__ ()
196197
@@ -213,6 +214,17 @@ def __init__(
213214            self .tokenizer .model_max_length  if  hasattr (self , "tokenizer" ) and  self .tokenizer  is  not None  else  77 
214215        )
215216        self .default_sample_size  =  128 
217+         if variant not in {"flux", "chroma"}:
218+             raise ValueError("`variant` must be `'flux'` or `'chroma'`.")
219+ 
220+         self .variant  =  variant 
221+ 
222+     def  _get_chroma_attn_mask (self , length : torch .Tensor , max_sequence_length : int ) ->  torch .Tensor :
223+         attention_mask  =  torch .zeros ((length .shape [0 ], max_sequence_length ), dtype = torch .bool , device = length .device )
224+         for  i , n_tokens  in  enumerate (length ):
225+             n_tokens  =  torch .max (n_tokens  +  1 , max_sequence_length )
226+             attention_mask [i , :n_tokens ] =  1 
227+         return  attention_mask 
216228
217229    def  _get_t5_prompt_embeds (
218230        self ,
@@ -236,7 +248,7 @@ def _get_t5_prompt_embeds(
236248            padding = "max_length" ,
237249            max_length = max_sequence_length ,
238250            truncation = True ,
239-             return_length = False ,
251+             return_length = ( self . variant   ==   "chroma" ) ,
240252            return_overflowing_tokens = False ,
241253            return_tensors = "pt" ,
242254        )
@@ -250,7 +262,15 @@ def _get_t5_prompt_embeds(
250262                f" {max_sequence_length} tokens: {removed_text}"
251263            )
252264
253-         prompt_embeds  =  self .text_encoder_2 (text_input_ids .to (device ), output_hidden_states = False )[0 ]
265+         prompt_embeds  =  self .text_encoder_2 (
266+             text_input_ids .to (device ),
267+             output_hidden_states = False ,
268+             attention_mask = (
269+                 self ._get_chroma_attn_mask (text_inputs .length , max_sequence_length ).to (device )
270+                 if  self .variant  ==  "chroma" 
271+                 else  None 
272+             ),
273+         )[0 ]
254274
255275        dtype  =  self .text_encoder_2 .dtype 
256276        prompt_embeds  =  prompt_embeds .to (dtype = dtype , device = device )
0 commit comments