@@ -348,9 +348,14 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
348348        return  components , state 
349349
350350
351- class  QwenImageVaeEncoderStep (ModularPipelineBlocks ):
351+ class  QwenImageVaeEncoderDynamicStep (ModularPipelineBlocks ):
352352    model_name  =  "qwenimage" 
353353
354+     def  __init__ (self , input_name : str  =  "image" , output_name : str  =  "image_latents" ):
355+         self .input_name  =  input_name 
356+         self .output_name  =  output_name 
357+         super ().__init__ ()
358+ 
354359    @property  
355360    def  description (self ) ->  str :
356361        return  "Vae Encoder step that encode the input image into a latent representation" 
@@ -370,15 +375,15 @@ def expected_components(self) -> List[ComponentSpec]:
370375    @property  
371376    def  inputs (self ) ->  List [InputParam ]:
372377        return  [
373-             InputParam ("image" , required = True ,  description = "The image to encode, should already be resized using resize step" ),
378+             InputParam (self . input_name , required = True ),
374379            InputParam ("generator" ),
375380        ]
376381
377382    @property  
378383    def  intermediate_outputs (self ) ->  List [OutputParam ]:
379384        return  [
380385            OutputParam (
381-                 "image_latents" ,
386+                 self . output_name ,
382387                type_hint = torch .Tensor ,
383388                description = "The latents representing the reference image" ,
384389            )
@@ -391,16 +396,20 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
391396        device  =  components ._execution_device 
392397        dtype  =  components .vae .dtype 
393398
394-         image  =  components .image_processor .preprocess (block_state .image )
399+         image  =  getattr (block_state , self .input_name )
400+ 
401+         image  =  components .image_processor .preprocess (image )
395402        image  =  image .unsqueeze (2 )
396403        image  =  image .to (device = device , dtype = dtype )
397404
398405
399406        # Encode image into latents 
400-         block_state . image_latents  =  encode_vae_image (
407+         image_latents  =  encode_vae_image (
401408            image = image , vae = components .vae , generator = block_state .generator , latent_channels = components .num_channels_latents 
402409        )
403410
411+         setattr (block_state , self .output_name , image_latents )
412+ 
404413        self .set_block_state (state , block_state )
405414
406415        return  components , state 
0 commit comments