@@ -191,9 +191,7 @@ def __init__(
191191        text_encoder : Qwen2_5_VLForConditionalGeneration ,
192192        tokenizer : Qwen2Tokenizer ,
193193        transformer : QwenImageTransformer2DModel ,
194-         controlnet : Union [
195-             QwenImageControlNetModel , QwenImageMultiControlNetModel 
196-         ],
194+         controlnet : Union [QwenImageControlNetModel , QwenImageMultiControlNetModel ],
197195    ):
198196        super ().__init__ ()
199197
@@ -701,7 +699,7 @@ def __call__(
701699                height = control_image .shape [3 ],
702700                width = control_image .shape [4 ],
703701            ).to (dtype = prompt_embeds .dtype , device = device )
704-          
702+ 
705703        else :
706704            if  isinstance (self .controlnet , QwenImageMultiControlNetModel ):
707705                control_images  =  []
@@ -723,12 +721,12 @@ def __call__(
723721
724722                    # vae encode 
725723                    self .vae_scale_factor  =  2  **  len (self .vae .temperal_downsample )
726-                     latents_mean  =  (torch . tensor ( self . vae . config . latents_mean ). view ( 1 ,  self . vae . config . z_dim ,  1 ,  1 ,  1 )). to ( 
727-                         device 
728-                     )
729-                     latents_std  =  1.0  /  torch .tensor (self .vae .config .latents_std ).view (1 ,  self . vae . config . z_dim ,  1 ,  1 ,  1 ). to ( 
730-                         device 
731-                     )
724+                     latents_mean  =  (
725+                         torch . tensor ( self . vae . config . latents_mean ). view ( 1 ,  self . vae . config . z_dim ,  1 ,  1 ,  1 ) 
726+                     ). to ( device ) 
727+                     latents_std  =  1.0  /  torch .tensor (self .vae .config .latents_std ).view (
728+                         1 ,  self . vae . config . z_dim ,  1 ,  1 ,  1 
729+                     ). to ( device ) 
732730
733731                    control_image_  =  retrieve_latents (self .vae .encode (control_image_ ), generator = generator )
734732                    control_image_  =  (control_image_  -  latents_mean ) *  latents_std 
@@ -818,7 +816,7 @@ def __call__(
818816                    if  isinstance (controlnet_cond_scale , list ):
819817                        controlnet_cond_scale  =  controlnet_cond_scale [0 ]
820818                    cond_scale  =  controlnet_cond_scale  *  controlnet_keep [i ]
821-                  
819+ 
822820                # controlnet 
823821                controlnet_block_samples  =  self .controlnet (
824822                    hidden_states = latents ,
0 commit comments