 # limitations under the License.

 import inspect
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Tuple, Union

 import numpy as np
 import torch

+from ...models import AutoencoderKL
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import logging
 from ...utils.torch_utils import randn_tensor
@@ -103,6 +104,61 @@ def calculate_shift(
     return mu


+def prepare_latents_img2img(
+    vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
+):
+    if isinstance(generator, list) and len(generator) != batch_size:
+        raise ValueError(
+            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+        )
+
+    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
+    latent_channels = vae.config.latent_channels
+
+    # VAE applies 8x compression on images but we must also account for packing which requires
+    # latent height and width to be divisible by 2.
+    height = 2 * (int(height) // (vae_scale_factor * 2))
+    width = 2 * (int(width) // (vae_scale_factor * 2))
+    shape = (batch_size, num_channels_latents, height, width)
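+    # per-patch position ids for the packed 2x2 latent grid, needed by the transformer's RoPE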
+    latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+    image = image.to(device=device, dtype=dtype)
+    if image.shape[1] != latent_channels:
+        image_latents = _encode_vae_image(vae, image=image, generator=generator)
+    else:
+        image_latents = image
+    if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+        # expand init_latents for batch_size
+        additional_image_per_prompt = batch_size // image_latents.shape[0]
+        image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+    elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+        raise ValueError(
+            f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+        )
+    else:
+        image_latents = torch.cat([image_latents], dim=0)
+
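+    # forward-noise the clean image latents to the starting timestep (the img2img entry point)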
+    noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+    latents = scheduler.scale_noise(image_latents, timestep, noise)
+    latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
+    return latents, latent_image_ids
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def _pack_latents(latents, batch_size, num_channels_latents, height, width):
     latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
     latents = latents.permute(0, 2, 4, 1, 3, 5)
@@ -125,6 +181,44 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
     return latent_image_ids.to(device=device, dtype=dtype)


+# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
+def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
+    if isinstance(generator, list):
+        image_latents = [
+            retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
+        ]
+        image_latents = torch.cat(image_latents, dim=0)
+    else:
+        image_latents = retrieve_latents(vae.encode(image), generator=generator)
+
+    image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
+
+    return image_latents
+
+
+def _get_timesteps_and_optionals(transformer, scheduler, latents, num_inference_steps, guidance_scale, sigmas, device):
+    image_seq_len = latents.shape[1]
+
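+    # default to a linear sigma schedule from 1.0 down to 1 / num_inference_steps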
+    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+    if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
+        sigmas = None
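+    # shift the timestep schedule according to the image sequence length (see calculate_shift above)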
+    mu = calculate_shift(
+        image_seq_len,
+        scheduler.config.get("base_image_seq_len", 256),
+        scheduler.config.get("max_image_seq_len", 4096),
+        scheduler.config.get("base_shift", 0.5),
+        scheduler.config.get("max_shift", 1.15),
+    )
+    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu)
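+    # guidance-distilled transformers consume guidance_scale as a conditioning tensor rather than via CFG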
+    if transformer.config.guidance_embeds:
+        guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+        guidance = guidance.expand(latents.shape[0])
+    else:
+        guidance = None
+
+    return timesteps, num_inference_steps, sigmas, guidance
+
+
 class FluxInputStep(PipelineBlock):
     model_name = "flux"

@@ -264,34 +358,103 @@ def intermediate_outputs(self) -> List[OutputParam]:
     def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
         block_state.device = components._execution_device
-        scheduler = components.scheduler

-        latents = block_state.latents
-        image_seq_len = latents.shape[1]
+        scheduler = components.scheduler
+        transformer = components.transformer

-        num_inference_steps = block_state.num_inference_steps
-        sigmas = block_state.sigmas
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
-        if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
-            sigmas = None
-        block_state.sigmas = sigmas
-        mu = calculate_shift(
-            image_seq_len,
-            scheduler.config.get("base_image_seq_len", 256),
-            scheduler.config.get("max_image_seq_len", 4096),
-            scheduler.config.get("base_shift", 0.5),
-            scheduler.config.get("max_shift", 1.15),
+        timesteps, num_inference_steps, sigmas, guidance = _get_timesteps_and_optionals(
+            transformer,
+            scheduler,
+            block_state.latents,
+            block_state.num_inference_steps,
+            block_state.guidance_scale,
+            block_state.sigmas,
+            block_state.device,
         )
-        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
-            scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
+        block_state.timesteps = timesteps
+        block_state.num_inference_steps = num_inference_steps
+        block_state.sigmas = sigmas
+        block_state.guidance = guidance
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class FluxImg2ImgSetTimestepsStep(PipelineBlock):
+    model_name = "flux"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+    @property
+    def description(self) -> str:
+        return "Step that sets the scheduler's timesteps and the initial latent timestep for image-to-image inference"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("num_inference_steps", default=50),
+            InputParam("timesteps"),
+            InputParam("sigmas"),
+            InputParam("guidance_scale", default=3.5),
+            InputParam("latents", type_hint=torch.Tensor),
+            InputParam("num_images_per_prompt", default=1),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[InputParam]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+            )
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+            OutputParam(
+                "num_inference_steps",
+                type_hint=int,
+                description="The number of denoising steps to perform at inference time",
+            ),
+            OutputParam(
+                "latent_timestep",
+                type_hint=torch.Tensor,
+                description="The timestep that represents the initial noise level for image-to-image generation",
+            ),
+            OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        block_state.device = components._execution_device
+
+        scheduler = components.scheduler
+        transformer = components.transformer
+
+        timesteps, num_inference_steps, sigmas, guidance = _get_timesteps_and_optionals(
+            transformer,
+            scheduler,
+            block_state.latents,
+            block_state.num_inference_steps,
+            block_state.guidance_scale,
+            block_state.sigmas,
+            block_state.device,
         )
-        if components.transformer.config.guidance_embeds:
-            guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
-            guidance = guidance.expand(latents.shape[0])
-        else:
-            guidance = None
+        block_state.timesteps = timesteps
+        block_state.num_inference_steps = num_inference_steps
+        block_state.sigmas = sigmas
         block_state.guidance = guidance

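+        # the first timestep encodes the initial noise level applied to the image latents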
+        batch_size = block_state.latents.shape[0]
+        block_state.latent_timestep = timesteps[:1].repeat(batch_size * block_state.num_images_per_prompt)
+
         self.set_block_state(state, block_state)
         return components, state

@@ -418,3 +581,96 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
         self.set_block_state(state, block_state)

         return components, state
+
+
+class FluxImg2ImgPrepareLatentsStep(PipelineBlock):
+    model_name = "flux"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the latents for the image-to-image generation process"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
+            InputParam("latents", type_hint=Optional[torch.Tensor]),
+            InputParam("num_images_per_prompt", type_hint=int, default=1),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[InputParam]:
+        return [
+            InputParam("generator"),
+            InputParam(
+                "image_latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
+            ),
+            InputParam(
+                "latent_timestep",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
+            ),
+            InputParam(
+                "batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+            ),
+            InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+            ),
+            OutputParam(
+                "latent_image_ids",
+                type_hint=torch.Tensor,
+                description="IDs computed from the image sequence needed for RoPE",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.height = block_state.height or components.default_height
+        block_state.width = block_state.width or components.default_width
+        block_state.device = components._execution_device
+        block_state.dtype = torch.bfloat16  # TODO: okay to hardcode this?
+        block_state.num_channels_latents = components.num_channels_latents
+
+        # TODO: implement `check_inputs`
+
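+        # only prepare latents here when they were not passed in directly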
+        if block_state.latents is None:
+            block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
+                components.vae,
+                components.scheduler,
+                block_state.image_latents,
+                block_state.latent_timestep,
+                block_state.batch_size * block_state.num_images_per_prompt,
+                block_state.num_channels_latents,
+                block_state.height,
+                block_state.width,
+                block_state.dtype,
+                block_state.device,
+                block_state.generator,
+            )
+
+        self.set_block_state(state, block_state)
+
+        return components, state