@@ -308,7 +308,6 @@ def __init__(
         vae_decoder: FluxVAEDecoder,
         vae_encoder: FluxVAEEncoder,
         load_text_encoder: bool = True,
-        use_cfg: bool = False,
         batch_cfg: bool = False,
         vae_tiled: bool = False,
         vae_tile_size: int = 256,
@@ -336,7 +335,6 @@ def __init__(
         self.vae_decoder = vae_decoder
         self.vae_encoder = vae_encoder
         self.load_text_encoder = load_text_encoder
-        self.use_cfg = use_cfg
         self.batch_cfg = batch_cfg
         self.ip_adapter = None
         self.redux = None
@@ -353,11 +351,15 @@ def __init__(
     def from_pretrained(
         cls,
         model_path_or_config: str | os.PathLike | FluxModelConfig,
+        load_text_encoder: bool = True,
+        batch_cfg: bool = False,
+        vae_tiled: bool = False,
+        vae_tile_size: int = 256,
+        vae_tile_stride: int = 256,
         control_type: ControlType = ControlType.normal,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         offload_mode: str | None = None,
-        load_text_encoder: bool = True,
         parallelism: int = 1,
         use_cfg_parallel: bool = False,
     ) -> "FluxImagePipeline":
@@ -454,6 +456,10 @@ def from_pretrained(
             vae_decoder=vae_decoder,
             vae_encoder=vae_encoder,
             load_text_encoder=load_text_encoder,
+            batch_cfg=batch_cfg,
+            vae_tiled=vae_tiled,
+            vae_tile_size=vae_tile_size,
+            vae_tile_stride=vae_tile_stride,
             control_type=control_type,
             device=device,
             dtype=dtype,
@@ -530,10 +536,9 @@ def predict_noise_with_cfg(
         controlnet_params: List[ControlNetParams],
         current_step: int,
         total_step: int,
-        use_cfg: bool = False,
         batch_cfg: bool = False,
     ):
-        if cfg_scale <= 1.0 or not use_cfg:
+        if cfg_scale <= 1.0:
             return self.predict_noise(
                 latents,
                 timestep,
@@ -583,6 +588,10 @@ def predict_noise_with_cfg(
             add_text_embeds = torch.cat([positive_add_text_embeds, negative_add_text_embeds], dim=0)
             latents = torch.cat([latents, latents], dim=0)
             timestep = torch.cat([timestep, timestep], dim=0)
+            image_emb = torch.cat([image_emb, image_emb], dim=0) if image_emb is not None else None
+            image_ids = torch.cat([image_ids, image_ids], dim=0)
+            text_ids = torch.cat([text_ids, text_ids], dim=0)
+            guidance = torch.cat([guidance, guidance], dim=0)
             positive_noise_pred, negative_noise_pred = self.predict_noise(
                 latents,
                 timestep,
@@ -676,8 +685,14 @@ def prepare_latents(
             num_inference_steps, mu=mu, sigma_min=1 / num_inference_steps, sigma_max=1.0
         )
         init_latents = latents.clone()
-        sigmas, timesteps = sigmas.to(device=self.device, dtype=self.dtype), timesteps.to(device=self.device, dtype=self.dtype)
-        init_latents, latents = init_latents.to(device=self.device, dtype=self.dtype), latents.to(device=self.device, dtype=self.dtype)
+        sigmas, timesteps = (
+            sigmas.to(device=self.device, dtype=self.dtype),
+            timesteps.to(device=self.device, dtype=self.dtype),
+        )
+        init_latents, latents = (
+            init_latents.to(device=self.device, dtype=self.dtype),
+            latents.to(device=self.device, dtype=self.dtype),
+        )
         return init_latents, latents, sigmas, timesteps

     def prepare_masked_latent(self, image: Image.Image, mask: Image.Image | None, height: int, width: int):
@@ -826,7 +841,7 @@ def __call__(
         # Encode prompts
         self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
         positive_prompt_emb, positive_add_text_embeds = self.encode_prompt(prompt, clip_skip=clip_skip)
-        if self.use_cfg and cfg_scale > 1:
+        if cfg_scale > 1:
             negative_prompt_emb, negative_add_text_embeds = self.encode_prompt(negative_prompt, clip_skip=clip_skip)
         else:
             negative_prompt_emb, negative_add_text_embeds = None, None
@@ -868,7 +883,6 @@ def __call__(
                 controlnet_params=controlnet_params,
                 current_step=i,
                 total_step=len(timesteps),
-                use_cfg=self.use_cfg,
                 batch_cfg=self.batch_cfg,
             )
             # Denoise
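A minimal usage sketch of the interface after this change, assuming the surrounding FluxImagePipeline API; the model path, prompt strings, and option values below are placeholders, not taken from this commit. The batch_cfg and VAE tiling options are now forwarded through from_pretrained, and classifier-free guidance is enabled by cfg_scale > 1 alone, since the use_cfg flag is gone.

import torch

# Hypothetical setup: the path and option values are illustrative only.
pipe = FluxImagePipeline.from_pretrained(
    "path/to/flux-model",
    batch_cfg=True,        # run positive/negative CFG branches in one batched forward pass
    vae_tiled=True,        # tiled VAE decode, now configurable via from_pretrained
    vae_tile_size=256,
    vae_tile_stride=256,
    device="cuda:0",
    dtype=torch.bfloat16,
)

# With use_cfg removed, passing cfg_scale > 1 is what turns on classifier-free
# guidance: the negative prompt is encoded and the CFG noise prediction is used.
image = pipe(
    prompt="a mountain lake at sunrise",
    negative_prompt="blurry, low quality",
    cfg_scale=3.5,
)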