@@ -191,7 +191,6 @@ def __init__(
         transformer: FluxTransformer2DModel,
         image_encoder: CLIPVisionModelWithProjection = None,
         feature_extractor: CLIPImageProcessor = None,
-        variant: str = "flux",
     ):
         super().__init__()

@@ -214,17 +213,6 @@ def __init__(
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
         self.default_sample_size = 128
-        if variant not in {"flux", "chroma"}:
-            raise ValueError("`variant` must be `'flux' or `'chroma'`.")
-
-        self.variant = variant
-
-    def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
-        attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
-        for i, n_tokens in enumerate(length):
-            n_tokens = torch.max(n_tokens + 1, max_sequence_length)
-            attention_mask[i, :n_tokens] = True
-        return attention_mask

     def _get_t5_prompt_embeds(
         self,
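
The removed `_get_chroma_attn_mask` built a per-prompt boolean padding mask over the T5 sequence, unmasking each prompt's real tokens plus one trailing pad/EOS position. As written, its `torch.max(n_tokens + 1, max_sequence_length)` reads like it meant `torch.min`, which may be part of why the path was dropped. A minimal vectorized sketch of the apparent intent (the helper name and `lengths` input are illustrative, not part of the pipeline):

```python
import torch

def build_padding_mask(lengths: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
    # lengths: shape (batch,), number of real tokens per prompt.
    # True at positions 0..n_tokens inclusive (prompt tokens plus one trailing
    # position); the arange comparison caps the span at max_sequence_length
    # by construction, so no explicit clamp is needed.
    positions = torch.arange(max_sequence_length, device=lengths.device)
    return positions.unsqueeze(0) <= lengths.unsqueeze(1)
```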
@@ -248,7 +236,7 @@ def _get_t5_prompt_embeds(
             padding="max_length",
             max_length=max_sequence_length,
             truncation=True,
-            return_length=True,
+            return_length=False,
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
@@ -262,10 +250,7 @@ def _get_t5_prompt_embeds(
                 f" {max_sequence_length} tokens: {removed_text}"
             )

-        text_inputs.attention_mask[:, : text_inputs.length + 1] = 1.0
-        prompt_embeds = self.text_encoder_2(
-            text_input_ids.to(device), output_hidden_states=False, attention_mask=text_inputs.attention_mask.to(device)
-        )[0]
+        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]

         dtype = self.text_encoder_2.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
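
With this hunk (and the `return_length=False` change above), `_get_t5_prompt_embeds` no longer builds or forwards an attention mask: the T5 encoder runs on the fixed-length padded ids and attends over pad tokens as well, which is the stock Flux behavior. A self-contained sketch of the simplified path (the small checkpoint is illustrative; Flux pipelines ship a much larger T5 as `text_encoder_2`):

```python
import torch
from transformers import T5EncoderModel, T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-small")  # illustrative checkpoint
encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

batch = tokenizer(
    ["a photo of a cat"],
    padding="max_length",
    max_length=512,
    truncation=True,
    return_tensors="pt",
)
# No attention_mask is passed, matching the pipeline after this change.
with torch.no_grad():
    prompt_embeds = encoder(batch.input_ids, output_hidden_states=False)[0]
print(prompt_embeds.shape)  # (1, 512, hidden_size)
```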
@@ -702,11 +687,11 @@ def __call__(
                 their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                 will be used.
             guidance_scale (`float`, *optional*, defaults to 3.5):
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
+                Guidance scale as defined in [Classifier-Free Diffusion
+                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+                the text `prompt`, usually at the expense of lower image quality.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
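
For orientation, the `w` in this docstring is the guidance weight from the Imagen paper's equation 2. In two-pass classifier-free guidance the combination is conventionally computed as below (variable names illustrative):

```python
# noise_pred_uncond / noise_pred_text: model outputs for the empty and real prompts.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```

Note that guidance-distilled Flux checkpoints instead feed `guidance_scale` to the transformer as a conditioning input rather than running two forward passes, so this snippet is purely for reference.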
@@ -715,7 +700,7 @@ def __call__(
             latents (`torch.FloatTensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will ge generated by sampling using the supplied random `generator`.
+                tensor will be generated by sampling using the supplied random `generator`.
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
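
The user-visible effect of this commit is that the pipeline constructor no longer accepts a `variant` kwarg and the Chroma-specific masking path is gone. A minimal usage sketch after the change (the model id is the usual Flux checkpoint, shown here as an assumption, not part of this diff):

```python
import torch
from diffusers import FluxPipeline

# Constructed without a `variant` kwarg; loading is otherwise unchanged.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = pipe("a photo of a forest at dawn", guidance_scale=3.5).images[0]
```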