 import numpy as np
 import torch
 
-from ...configuration_utils import FrozenDict
-from ...image_processor import VaeImageProcessor
 from ...models import QwenImageControlNetModel, QwenImageMultiControlNetModel
-from ...pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils.torch_utils import randn_tensor, unwrap_module
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
@@ -243,45 +240,62 @@ def expected_components(self) -> List[ComponentSpec]:
         return [
             ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
         ]
-
+
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="latents", required=True, type_hint=torch.Tensor, description="The initial random noise. Can be generated in the prepare latents step."),
-            InputParam(name="image_latents", required=True, type_hint=torch.Tensor, description="The image latents to use for the denoising process. Can be generated in vae encoder + pack latents step."),
-            InputParam(name="timesteps", required=True, type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."),
-            InputParam(name="batch_size", required=True, type_hint=int, description="Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in the expand text input step."),
+            InputParam(
+                name="latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The initial random noise. Can be generated in the prepare latents step.",
+            ),
+            InputParam(
+                name="image_latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The image latents to use for the denoising process. Can be generated in vae encoder + pack latents step.",
+            ),
+            InputParam(
+                name="timesteps",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+            ),
+            InputParam(
+                name="batch_size",
+                required=True,
+                type_hint=int,
+                description="Number of prompts; the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in the expand text input step.",
+            ),
             InputParam(name="num_images_per_prompt", required=True),
         ]
 
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
-            OutputParam(name="initial_noise", type_hint=torch.Tensor, description="The initial random noise used for inpainting denoising."),
+            OutputParam(
+                name="initial_noise",
+                type_hint=torch.Tensor,
+                description="The initial random noise used for inpainting denoising.",
+            ),
         ]
-
-
+
     @staticmethod
     def check_inputs(image_latents, latents, batch_size):
-
         if image_latents.shape[0] != batch_size:
             raise ValueError(
                 f"`image_latents` must have batch size {batch_size}, but got {image_latents.shape[0]}"
             )
 
         if image_latents.ndim != 3:
             raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
-
-
+
         if latents.shape[0] != batch_size:
-            raise ValueError(
-                f"`latents` must have batch size {batch_size}, but got {latents.shape[0]}"
-            )
-
-
+            raise ValueError(f"`latents` must have batch size {batch_size}, but got {latents.shape[0]}")
+
     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
-
         block_state = self.get_block_state(state)
         final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
 
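For orientation, here is the shape contract that check_inputs above enforces, as a self-contained sketch. All sizes are illustrative; QwenImage latents are patchified into 3-D (batch, sequence, channels) tensors before this block runs.

    import torch

    batch_size = 2
    # patchified latents: (batch, sequence_length, channels); the sizes are made up
    image_latents = torch.randn(batch_size, 1024, 64)
    latents = torch.randn(batch_size, 1024, 64)

    # the three conditions check_inputs verifies:
    assert image_latents.shape[0] == batch_size  # else ValueError
    assert image_latents.ndim == 3               # must already be patchified
    assert latents.shape[0] == batch_size        # else ValueError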
@@ -290,43 +304,52 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
             latents=block_state.latents,
             batch_size=final_batch_size,
         )
-
+
         # prepare latent timestep
         latent_timestep = block_state.timesteps[:1].repeat(final_batch_size)
-
+
         # make copy of initial_noise
         block_state.initial_noise = block_state.latents
 
         # scale noise
-        block_state.latents = components.scheduler.scale_noise(block_state.image_latents, latent_timestep, block_state.latents)
+        block_state.latents = components.scheduler.scale_noise(
+            block_state.image_latents, latent_timestep, block_state.latents
+        )
 
         self.set_block_state(state, block_state)
-
-        return components, state
+
+        return components, state
 
 
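The scale_noise call above is what seeds inpainting: for flow matching, the scheduler linearly interpolates between the clean image latents and the fresh noise at the first (strength-dependent) timestep. A conceptual sketch of that behavior (this mirrors what FlowMatchEulerDiscreteScheduler.scale_noise computes, not a verbatim copy of it):

    import torch

    def scale_noise_sketch(image_latents: torch.Tensor, noise: torch.Tensor, sigma: float) -> torch.Tensor:
        # sigma is the noise level at latent_timestep: sigma == 1.0 yields pure
        # noise (full denoising), sigma == 0.0 returns the input image unchanged.
        return sigma * noise + (1.0 - sigma) * image_latents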
 class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
     model_name = "qwenimage"
-
+
     @property
     def description(self) -> str:
         return "Step that creates the mask latents for the inpainting process. Should be run with the patchify latents step."
-
+
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam(name="mask_image", required=True, type_hint=torch.Tensor, description="The mask to use for the inpainting process."),
+            InputParam(
+                name="mask_image",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The mask to use for the inpainting process.",
+            ),
             InputParam(name="height", required=True),
             InputParam(name="width", required=True),
             InputParam(name="dtype", required=True),
         ]
-
+
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
-            OutputParam(name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process."),
+            OutputParam(
+                name="mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process."
+            ),
         ]
-
+
     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
@@ -342,14 +365,14 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state.mask = torch.nn.functional.interpolate(
             block_state.mask_image,
             size=(height_latents, width_latents),
-        )
+        )
 
         block_state.mask = block_state.mask.unsqueeze(2)
         block_state.mask = block_state.mask.repeat(1, components.num_channels_latents, 1, 1, 1)
         block_state.mask = block_state.mask.to(device=device, dtype=block_state.dtype)
-
+
         self.set_block_state(state, block_state)
-
+
         return components, state
 
 
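Taken together, the mask path downsamples the pixel-space mask to the latent grid and broadcasts it across the latent channels so it can gate latents elementwise. A self-contained sketch; the sizes, and the assumption that height_latents/width_latents are the latent-space spatial dims (e.g. height // vae_scale_factor), are illustrative:

    import torch
    import torch.nn.functional as F

    mask_image = torch.ones(1, 1, 1024, 1024)  # pixel-space mask, illustrative size
    height_latents, width_latents = 128, 128   # assumed latent-space dims
    num_channels_latents = 16                  # illustrative

    mask = F.interpolate(mask_image, size=(height_latents, width_latents))
    mask = mask.unsqueeze(2)                              # add a frame dim -> (B, 1, 1, H, W)
    mask = mask.repeat(1, num_channels_latents, 1, 1, 1)  # broadcast across latent channels
    print(mask.shape)  # torch.Size([1, 16, 1, 128, 128])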
@@ -381,14 +404,14 @@ def __init__(self, input_names: List[str] = ["image_latents"]):
             input_names = [input_names]
         self._latents_input_names = input_names
         super().__init__()
-
+
     @staticmethod
     def check_input_shape(latents_input, latents_input_name, batch_size):
         if latents_input is not None and latents_input.shape[0] != 1 and latents_input.shape[0] != batch_size:
             raise ValueError(
                 f"`{latents_input_name}` must have batch size 1 or {batch_size}, but got {latents_input.shape[0]}"
             )
-
+
         if latents_input.ndim != 5 and latents_input.ndim != 4:
             raise ValueError(f"`{latents_input_name}` must have 4 or 5 dimensions, but got {latents_input.ndim}")
 
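Per check_input_shape, each latents input must arrive unpatchified (4-D or 5-D) with batch size either 1 or the final batch size; presumably the block then repeats batch-1 inputs up to the final batch size, though the expansion itself is outside this hunk. Illustratively:

    import torch

    batch_size = 4
    ok_single = torch.randn(1, 16, 1, 128, 128)           # batch 1: allowed, to be expanded
    ok_full = torch.randn(batch_size, 16, 1, 128, 128)    # already at final batch size
    bad_batch = torch.randn(2, 16, 1, 128, 128)           # shape[0] not in {1, 4}: ValueError
    bad_rank = torch.randn(batch_size, 1024, 64)          # 3-D (patchified): ValueError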
@@ -526,11 +549,12 @@ def inputs(self) -> List[InputParam]:
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process. Can be generated in set_timesteps step."
+                name="timesteps",
+                type_hint=torch.Tensor,
+                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
             ),
         ]
 
-
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
@@ -609,9 +633,14 @@ def intermediate_outputs(self) -> List[OutputParam]:
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
-
         block_state.img_shapes = [
-            [(1, block_state.height // components.vae_scale_factor // 2, block_state.width // components.vae_scale_factor // 2)]
+            [
+                (
+                    1,
+                    block_state.height // components.vae_scale_factor // 2,
+                    block_state.width // components.vae_scale_factor // 2,
+                )
+            ]
             * block_state.batch_size
         ]
         block_state.txt_seq_lens = (
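For reference, each img_shapes entry describes the packed latent grid for one image: the VAE downsamples by vae_scale_factor, and the extra // 2 presumably reflects QwenImage's 2x2 latent patch packing, which halves each spatial dim again. With illustrative numbers:

    height, width = 1024, 1024
    vae_scale_factor = 8  # illustrative; the real value comes from the VAE config
    batch_size = 2

    img_shapes = [
        [(1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)]
    ] * batch_size
    print(img_shapes)  # [[(1, 64, 64)], [(1, 64, 64)]]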