Skip to content

Commit f28a8c2

Browse files
xduzhangjiayuyiyixuxuasomoza
authored
fix from_transformer() with extra conditioning channels (#9364)
* fix from_transformer() with extra conditioning channels * style fix --------- Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: Álvaro Somoza <[email protected]>
1 parent 2c6a6c9 commit f28a8c2

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/diffusers/models/controlnet_sd3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,12 @@ def _set_gradient_checkpointing(self, module, value=False):
242242
module.gradient_checkpointing = value
243243

244244
@classmethod
245-
def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True):
245+
def from_transformer(
246+
cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
247+
):
246248
config = transformer.config
247249
config["num_layers"] = num_layers or config.num_layers
250+
config["extra_conditioning_channels"] = num_extra_conditioning_channels
248251
controlnet = cls(**config)
249252

250253
if load_weights_from_transformer:

0 commit comments

Comments
 (0)