Skip to content

Commit 7be937e

Browse files
committed
Use conditioning_embedding_channels instead of is_xlabs_controlnet in config
1 parent b8c9496 commit 7be937e

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

src/diffusers/models/controlnet_flux.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
guidance_embeds: bool = False,
5656
axes_dims_rope: List[int] = [16, 56, 56],
5757
num_mode: int = None,
58-
is_xlabs_controlnet: bool = False,
58+
conditioning_embedding_channels: int = None,
5959
):
6060
super().__init__()
6161
self.out_channels = in_channels
@@ -107,13 +107,14 @@ def __init__(
107107
if self.union:
108108
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
109109

110-
if self.is_xlabs_controlnet:
110+
if conditioning_embedding_channels is not None:
111111
self.input_hint_block = ControlNetConditioningEmbedding(
112-
conditioning_embedding_channels=16,
112+
conditioning_embedding_channels=conditioning_embedding_channels,
113113
block_out_channels=(16,16,16,16)
114114
)
115115
self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
116116
else:
117+
self.input_hint_block = None
117118
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
118119

119120
self.gradient_checkpointing = False
@@ -277,7 +278,7 @@ def forward(
277278
)
278279
hidden_states = self.x_embedder(hidden_states)
279280

280-
if self.is_xlabs_controlnet:
281+
if self.input_hint_block is not None:
281282
controlnet_cond = self.input_hint_block(controlnet_cond)
282283
batch_size, channels, height_pw, width_pw = controlnet_cond.shape
283284
height = height_pw // self.config.patch_size

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def __call__(
752752
)
753753
height, width = control_image.shape[-2:]
754754

755-
if not self.controlnet.is_xlabs_controlnet:
755+
if self.controlnet.input_hint_block is None:
756756
# vae encode
757757
control_image = self.vae.encode(control_image).latent_dist.sample()
758758
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor

0 commit comments

Comments
 (0)