@@ -754,19 +754,22 @@ def __call__(
754754 )
755755 height , width = control_image .shape [- 2 :]
756756
757- # vae encode
758- control_image = self .vae .encode (control_image ).latent_dist .sample ()
759- control_image = (control_image - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
760-
761- # pack
762- height_control_image , width_control_image = control_image .shape [2 :]
763- control_image = self ._pack_latents (
764- control_image ,
765- batch_size * num_images_per_prompt ,
766- num_channels_latents ,
767- height_control_image ,
768- width_control_image ,
769- )
757+ # xlab controlnet has an input_hint_block and instantx controlnet does not
758+ controlnet_blocks_repeat = False if self .controlnet .input_hint_block is None else True
759+ if self .controlnet .input_hint_block is None :
760+ # vae encode
761+ control_image = self .vae .encode (control_image ).latent_dist .sample ()
762+ control_image = (control_image - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
763+
764+ # pack
765+ height_control_image , width_control_image = control_image .shape [2 :]
766+ control_image = self ._pack_latents (
767+ control_image ,
768+ batch_size * num_images_per_prompt ,
769+ num_channels_latents ,
770+ height_control_image ,
771+ width_control_image ,
772+ )
770773
771774 # Here we ensure that `control_mode` has the same length as the control_image.
772775 if control_mode is not None :
@@ -777,8 +780,9 @@ def __call__(
777780
778781 elif isinstance (self .controlnet , FluxMultiControlNetModel ):
779782 control_images = []
780-
781- for control_image_ in control_image :
783+ # xlab controlnet has an input_hint_block and instantx controlnet does not
784+ controlnet_blocks_repeat = False if self .controlnet .nets [0 ].input_hint_block is None else True
785+ for i , control_image_ in enumerate (control_image ):
782786 control_image_ = self .prepare_image (
783787 image = control_image_ ,
784788 width = width ,
@@ -790,20 +794,20 @@ def __call__(
790794 )
791795 height , width = control_image_ .shape [- 2 :]
792796
793- # vae encode
794- control_image_ = self . vae . encode ( control_image_ ). latent_dist . sample ()
795- control_image_ = ( control_image_ - self .vae .config . shift_factor ) * self . vae . config . scaling_factor
796-
797- # pack
798- height_control_image , width_control_image = control_image_ . shape [ 2 :]
799- control_image_ = self . _pack_latents (
800- control_image_ ,
801- batch_size * num_images_per_prompt ,
802- num_channels_latents ,
803- height_control_image ,
804- width_control_image ,
805- )
806-
797+ if self . controlnet . nets [ 0 ]. input_hint_block is None :
798+ # vae encode
799+ control_image_ = self .vae .encode ( control_image_ ). latent_dist . sample ()
800+ control_image_ = ( control_image_ - self . vae . config . shift_factor ) * self . vae . config . scaling_factor
801+
802+ # pack
803+ height_control_image , width_control_image = control_image_ . shape [ 2 :]
804+ control_image_ = self . _pack_latents (
805+ control_image_ ,
806+ batch_size * num_images_per_prompt ,
807+ num_channels_latents ,
808+ height_control_image ,
809+ width_control_image ,
810+ )
807811 control_images .append (control_image_ )
808812
809813 control_image = control_images
@@ -927,6 +931,7 @@ def __call__(
927931 img_ids = latent_image_ids ,
928932 joint_attention_kwargs = self .joint_attention_kwargs ,
929933 return_dict = False ,
934+ controlnet_blocks_repeat = controlnet_blocks_repeat ,
930935 )[0 ]
931936
932937 # compute the previous noisy sample x_t -> x_t-1
0 commit comments