@@ -118,15 +118,6 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")


-# TODO: align this with Qwen patchifier
-def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-    latents = latents.permute(0, 2, 4, 1, 3, 5)
-    latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
-    return latents
-
-
 def _get_initial_timesteps_and_optionals(
     transformer,
     scheduler,
@@ -398,16 +389,15 @@ def prepare_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        # TODO: move packing latents code to a patchifier
+        # TODO: move packing latents code to a patchifier similar to Qwen
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
+        latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)

         return latents

     @torch.no_grad()
     def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-
         block_state.height = block_state.height or components.default_height
         block_state.width = block_state.width or components.default_width
         block_state.device = components._execution_device
@@ -557,3 +547,73 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
         self.set_block_state(state, block_state)

         return components, state
+
+
+class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
+    model_name = "flux-kontext"
+
+    @property
+    def description(self) -> str:
+        return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps."
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(name="image_height"),
+            InputParam(name="image_width"),
+            InputParam(name="height"),
+            InputParam(name="width"),
+            InputParam(name="prompt_embeds"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                name="txt_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=List[int],
+                description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
+            ),
+            OutputParam(
+                name="img_ids",
+                kwargs_type="denoiser_input_fields",
+                type_hint=List[int],
+                description="The sequence lengths of the image latents, used for RoPE calculation.",
+            ),
+        ]
+
+    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        prompt_embeds = block_state.prompt_embeds
+        device, dtype = prompt_embeds.device, prompt_embeds.dtype
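+        # Flux uses 3-axis RoPE position ids; text tokens all get id (0, 0, 0), one row per prompt token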
+        block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
+            device=prompt_embeds.device, dtype=prompt_embeds.dtype
+        )
+
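+        # if a reference image is passed (Kontext image editing), build RoPE ids for its latent grid as well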
+        img_ids = None
+        if (
+            getattr(block_state, "image_height", None) is not None
+            and getattr(block_state, "image_width", None) is not None
+        ):
+            image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
+            image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
+            img_ids = FluxPipeline._prepare_latent_image_ids(
+                None, image_latent_height // 2, image_latent_width // 2, device, dtype
+            )
+            # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+            img_ids[..., 0] = 1
+
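+        # RoPE ids for the target latent grid, derived from the requested output height/width;
+        # reference-image ids (if any) are appended after them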
+        height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+        width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+        latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
+
+        if img_ids is not None:
+            latent_ids = torch.cat([latent_ids, img_ids], dim=0)
+
+        block_state.img_ids = latent_ids
+
+        self.set_block_state(state, block_state)
+
+        return components, state