@@ -105,26 +105,26 @@ def calculate_shift(
105105
106106# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents 
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
    """Rearrange spatial latents into a sequence of flattened 2x2 patches.

    Args:
        latents: Tensor that can be viewed as
            `(batch_size, num_channels_latents, height, width)`.
        batch_size: Size of the leading (batch) dimension.
        num_channels_latents: Number of latent channels.
        height: Latent height; floor-divided by 2 to get the patch-row count.
        width: Latent width; floor-divided by 2 to get the patch-column count.

    Returns:
        Tensor of shape
        `(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)`.
    """
    patch_rows = height // 2
    patch_cols = width // 2

    # Split each spatial axis into (patch index, offset inside the 2x2 patch).
    patched = latents.view(batch_size, num_channels_latents, patch_rows, 2, patch_cols, 2)

    # Bring the patch indices forward: (batch, patch_row, patch_col, channel, dy, dx).
    patched = patched.permute(0, 2, 4, 1, 3, 5)

    # One sequence entry per patch; channel and intra-patch offsets are flattened together.
    return patched.reshape(batch_size, patch_rows * patch_cols, num_channels_latents * 4)
113113
114114
115115# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids 
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    """Build a flat table of (0, row, column) ids for a `height` x `width` grid.

    Args:
        batch_size: Not used by this function; kept as part of the signature.
        height: Number of grid rows.
        width: Number of grid columns.
        device: Device the returned tensor is moved to.
        dtype: Dtype the returned tensor is cast to.

    Returns:
        Tensor of shape `(height * width, 3)` in row-major order. Channel 0 is
        zero, channel 1 holds the row index and channel 2 the column index.
    """
    ids = torch.zeros(height, width, 3)

    # Broadcast a column vector of row indices and a row vector of column
    # indices over the grid; channel 0 is left at zero.
    row_index = torch.arange(height).unsqueeze(1)
    col_index = torch.arange(width).unsqueeze(0)
    ids[..., 1] += row_index
    ids[..., 2] += col_index

    # Collapse the two spatial axes so that each grid cell becomes one row.
    n_rows, n_cols, n_channels = ids.shape
    flat_ids = ids.reshape(n_rows * n_cols, n_channels)

    return flat_ids.to(device=device, dtype=dtype)
128128
129129
130130class  FluxInputStep (PipelineBlock ):
@@ -180,13 +180,11 @@ def intermediate_outputs(self) -> List[str]:
180180            OutputParam (
181181                "prompt_embeds" ,
182182                type_hint = torch .Tensor ,
183-                 # kwargs_type="guider_input_fields",  # already in intermedites state but declare here again for guider_input_fields 
184183                description = "text embeddings used to guide the image generation" ,
185184            ),
186185            OutputParam (
187186                "pooled_prompt_embeds" ,
188187                type_hint = torch .Tensor ,
189-                 # kwargs_type="guider_input_fields",  # already in intermedites state but declare here again for guider_input_fields 
190188                description = "pooled text embeddings used to guide the image generation" ,
191189            ),
192190            # TODO: support negative embeddings? 
@@ -235,10 +233,10 @@ def description(self) -> str:
235233    def  inputs (self ) ->  List [InputParam ]:
236234        return  [
237235            InputParam ("num_inference_steps" , default = 50 ),
238-             InputParam ("timesteps" ),  
236+             InputParam ("timesteps" ),
239237            InputParam ("sigmas" ),
240238            InputParam ("guidance_scale" , default = 3.5 ),
241-             InputParam ("latents" , type_hint = torch .Tensor )
239+             InputParam ("latents" , type_hint = torch .Tensor ), 
242240        ]
243241
244242    @property  
@@ -261,7 +259,7 @@ def intermediate_outputs(self) -> List[OutputParam]:
261259                type_hint = int ,
262260                description = "The number of denoising steps to perform at inference time" ,
263261            ),
264-             OutputParam ("guidance" , type_hint = torch .Tensor , description = "Optional guidance to be used." )
262+             OutputParam ("guidance" , type_hint = torch .Tensor , description = "Optional guidance to be used." ), 
265263        ]
266264
267265    @torch .no_grad () 
@@ -340,10 +338,11 @@ def intermediate_outputs(self) -> List[OutputParam]:
340338                "latents" , type_hint = torch .Tensor , description = "The initial latents to use for the denoising process" 
341339            ),
342340            OutputParam (
343-                 "latent_image_ids" , type_hint = torch .Tensor , description = "IDs computed from the image sequence needed for RoPE" 
344-             )
341+                 "latent_image_ids" ,
342+                 type_hint = torch .Tensor ,
343+                 description = "IDs computed from the image sequence needed for RoPE" ,
344+             ),
345345        ]
346-         
347346
348347    @staticmethod  
349348    def  check_inputs (components , block_state ):
@@ -417,7 +416,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
417416            block_state .generator ,
418417            block_state .latents ,
419418        )
420-          
419+ 
421420        self .set_block_state (state , block_state )
422421
423422        return  components , state 
0 commit comments