@@ -932,19 +932,22 @@ def __call__(
932932 )
933933 height , width = control_image .shape [- 2 :]
934934
935- # vae encode
936- control_image = self .vae .encode (control_image ).latent_dist .sample ()
937- control_image = (control_image - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
938-
939- # pack
940- height_control_image , width_control_image = control_image .shape [2 :]
941- control_image = self ._pack_latents (
942- control_image ,
943- batch_size * num_images_per_prompt ,
944- num_channels_latents ,
945- height_control_image ,
946- width_control_image ,
947- )
935+ # xlab controlnet has a input_hint_block and instantx controlnet does not
936+ controlnet_blocks_repeat = False if self .controlnet .input_hint_block is None else True
937+ if self .controlnet .input_hint_block is None :
938+ # vae encode
939+ control_image = self .vae .encode (control_image ).latent_dist .sample ()
940+ control_image = (control_image - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
941+
942+ # pack
943+ height_control_image , width_control_image = control_image .shape [2 :]
944+ control_image = self ._pack_latents (
945+ control_image ,
946+ batch_size * num_images_per_prompt ,
947+ num_channels_latents ,
948+ height_control_image ,
949+ width_control_image ,
950+ )
948951
949952 # set control mode
950953 if control_mode is not None :
@@ -954,7 +957,9 @@ def __call__(
954957 elif isinstance (self .controlnet , FluxMultiControlNetModel ):
955958 control_images = []
956959
957- for control_image_ in control_image :
960+ # xlab controlnet has a input_hint_block and instantx controlnet does not
961+ controlnet_blocks_repeat = False if self .controlnet .nets [0 ].input_hint_block is None else True
962+ for i , control_image_ in enumerate (control_image ):
958963 control_image_ = self .prepare_image (
959964 image = control_image_ ,
960965 width = width ,
@@ -966,19 +971,20 @@ def __call__(
966971 )
967972 height , width = control_image_ .shape [- 2 :]
968973
969- # vae encode
970- control_image_ = self .vae .encode (control_image_ ).latent_dist .sample ()
971- control_image_ = (control_image_ - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
972-
973- # pack
974- height_control_image , width_control_image = control_image_ .shape [2 :]
975- control_image_ = self ._pack_latents (
976- control_image_ ,
977- batch_size * num_images_per_prompt ,
978- num_channels_latents ,
979- height_control_image ,
980- width_control_image ,
981- )
974+ if self .controlnet .nets [0 ].input_hint_block is None :
975+ # vae encode
976+ control_image_ = self .vae .encode (control_image_ ).latent_dist .sample ()
977+ control_image_ = (control_image_ - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
978+
979+ # pack
980+ height_control_image , width_control_image = control_image_ .shape [2 :]
981+ control_image_ = self ._pack_latents (
982+ control_image_ ,
983+ batch_size * num_images_per_prompt ,
984+ num_channels_latents ,
985+ height_control_image ,
986+ width_control_image ,
987+ )
982988
983989 control_images .append (control_image_ )
984990
@@ -1129,6 +1135,7 @@ def __call__(
11291135 img_ids = latent_image_ids ,
11301136 joint_attention_kwargs = self .joint_attention_kwargs ,
11311137 return_dict = False ,
1138+ controlnet_blocks_repeat = controlnet_blocks_repeat ,
11321139 )[0 ]
11331140
11341141 # compute the previous noisy sample x_t -> x_t-1
0 commit comments