@@ -230,6 +230,7 @@ def __init__(
230230        )
231231        self .default_sample_size  =  128 
232232
233+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds 
233234    def  _get_t5_prompt_embeds (
234235        self ,
235236        prompt : Union [str , List [str ]] =  None ,
@@ -279,6 +280,7 @@ def _get_t5_prompt_embeds(
279280
280281        return  prompt_embeds 
281282
283+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds 
282284    def  _get_clip_prompt_embeds (
283285        self ,
284286        prompt : Union [str , List [str ]],
@@ -323,6 +325,7 @@ def _get_clip_prompt_embeds(
323325
324326        return  prompt_embeds 
325327
328+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt 
326329    def  encode_prompt (
327330        self ,
328331        prompt : Union [str , List [str ]],
@@ -513,6 +516,7 @@ def encode_text_with_editing(
513516            enabled_editing_prompts ,
514517        )
515518
519+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image 
516520    def  encode_image (self , image , device , num_images_per_prompt ):
517521        dtype  =  next (self .image_encoder .parameters ()).dtype 
518522
@@ -524,6 +528,7 @@ def encode_image(self, image, device, num_images_per_prompt):
524528        image_embeds  =  image_embeds .repeat_interleave (num_images_per_prompt , dim = 0 )
525529        return  image_embeds 
526530
531+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds 
527532    def  prepare_ip_adapter_image_embeds (
528533        self , ip_adapter_image , ip_adapter_image_embeds , device , num_images_per_prompt 
529534    ):
@@ -555,6 +560,7 @@ def prepare_ip_adapter_image_embeds(
555560
556561        return  ip_adapter_image_embeds 
557562
563+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs 
558564    def  check_inputs (
559565        self ,
560566        prompt ,
@@ -633,6 +639,7 @@ def check_inputs(
633639            raise  ValueError (f"`max_sequence_length` cannot be greater than 512 but is { max_sequence_length }  )
634640
635641    @staticmethod  
642+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids 
636643    def  _prepare_latent_image_ids (batch_size , height , width , device , dtype ):
637644        latent_image_ids  =  torch .zeros (height , width , 3 )
638645        latent_image_ids [..., 1 ] =  latent_image_ids [..., 1 ] +  torch .arange (height )[:, None ]
@@ -647,6 +654,7 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
647654        return  latent_image_ids .to (device = device , dtype = dtype )
648655
649656    @staticmethod  
657+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents 
650658    def  _pack_latents (latents , batch_size , num_channels_latents , height , width ):
651659        latents  =  latents .view (batch_size , num_channels_latents , height  //  2 , 2 , width  //  2 , 2 )
652660        latents  =  latents .permute (0 , 2 , 4 , 1 , 3 , 5 )
@@ -655,6 +663,7 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
655663        return  latents 
656664
657665    @staticmethod  
666+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents 
658667    def  _unpack_latents (latents , height , width , vae_scale_factor ):
659668        batch_size , num_patches , channels  =  latents .shape 
660669
@@ -670,20 +679,23 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
670679
671680        return  latents 
672681
682+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing 
673683    def  enable_vae_slicing (self ):
674684        r""" 
675685        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 
676686        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 
677687        """ 
678688        self .vae .enable_slicing ()
679689
690+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing 
680691    def  disable_vae_slicing (self ):
681692        r""" 
682693        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 
683694        computing decoding in one step. 
684695        """ 
685696        self .vae .disable_slicing ()
686697
698+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling 
687699    def  enable_vae_tiling (self ):
688700        r""" 
689701        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 
@@ -692,13 +704,15 @@ def enable_vae_tiling(self):
692704        """ 
693705        self .vae .enable_tiling ()
694706
707+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling 
695708    def  disable_vae_tiling (self ):
696709        r""" 
697710        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 
698711        computing decoding in one step. 
699712        """ 
700713        self .vae .disable_tiling ()
701714
715+     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents 
702716    def  prepare_latents (
703717        self ,
704718        batch_size ,
@@ -1171,7 +1185,7 @@ def __call__(
11711185                        device = device ,
11721186                        dtype = noise_guidance .dtype ,
11731187                    )
1174-                      # noise_guidance_edit = torch.zeros_like(noise_guidance) 
1188+ 
11751189                    warmup_inds  =  []
11761190                    for  c , noise_pred_edit_concept  in  enumerate (noise_pred_edit_concepts ):
11771191                        if  isinstance (edit_guidance_scale , list ):
@@ -1244,9 +1258,6 @@ def __call__(
12441258                        )
12451259
12461260                        noise_guidance_edit [c , :, :, :] =  noise_guidance_edit_tmp 
1247-                         # noise_guidance_edit[c] = noise_guidance_edit_tmp 
1248- 
1249-                         # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp 
12501261
12511262                    warmup_inds  =  torch .tensor (warmup_inds ).to (device )
12521263                    if  len (noise_pred_edit_concepts ) >  warmup_inds .shape [0 ] >  0 :
@@ -1258,7 +1269,6 @@ def __call__(
12581269                            concept_weights_tmp  <  0 , torch .zeros_like (concept_weights_tmp ), concept_weights_tmp 
12591270                        )
12601271                        concept_weights_tmp  =  concept_weights_tmp  /  concept_weights_tmp .sum (dim = 0 )
1261-                         # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) 
12621272
12631273                        noise_guidance_edit_tmp  =  torch .index_select (noise_guidance_edit .to (device ), 0 , warmup_inds )
12641274                        noise_guidance_edit_tmp  =  torch .einsum (
0 commit comments