
Commit 16170c6

add sd1.5 compatibility to controlnet-xs and fix unused_parameters error during training (#8606)
* add sd1.5 compatibility to controlnet-xs
* set use_linear_projection by base_block
* refine code style
1 parent 4408047 commit 16170c6
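
For context, a minimal usage sketch of what this change enables (not part of the commit): pairing a ControlNet-XS adapter with a Stable Diffusion 1.5 UNet via the `from_unet` helpers touched below. The checkpoint id and the `size_ratio` value are illustrative assumptions, and the exact keyword names should be checked against the installed diffusers version.

# Hedged sketch, assuming the ControlNet-XS classes shown in this file
# (ControlNetXSAdapter.from_unet, UNetControlNetXSModel.from_unet).
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel

# Load an SD1.5 UNet; its config has use_linear_projection=False,
# which the previously hard-coded True could not represent.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # illustrative checkpoint id
)

# Build an adapter whose blocks now mirror the base UNet's projection type,
# then fuse base UNet + adapter into one trainable ControlNet-XS model.
adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)
model = UNetControlNetXSModel.from_unet(unet, adapter)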

File tree

1 file changed (+35 -8)

src/diffusers/models/controlnet_xs.py

Lines changed: 35 additions & 8 deletions
@@ -114,6 +114,7 @@ def get_down_block_adapter(
     cross_attention_dim: Optional[int] = 1024,
     add_downsample: bool = True,
     upcast_attention: Optional[bool] = False,
+    use_linear_projection: Optional[bool] = True,
 ):
     num_layers = 2  # only support sd + sdxl

@@ -152,7 +153,7 @@ def get_down_block_adapter(
                     in_channels=ctrl_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
                 )
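
The hard-coded `use_linear_projection=True` above matched SDXL but not SD1.5: in `Transformer2DModel`, this flag selects how image features are projected into the attention dimension, and SD1.5 checkpoints use the convolutional variant. A minimal sketch of the distinction follows; the helper name is hypothetical and the real logic lives inside diffusers' `Transformer2DModel`.

# Hypothetical helper illustrating what use_linear_projection selects.
import torch.nn as nn

def make_proj_in(in_channels: int, inner_dim: int, use_linear_projection: bool) -> nn.Module:
    if use_linear_projection:
        # SDXL-style checkpoints: project with a linear layer.
        return nn.Linear(in_channels, inner_dim)
    # SD1.5-style checkpoints: project with a 1x1 convolution.
    return nn.Conv2d(in_channels, inner_dim, kernel_size=1)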
@@ -200,6 +201,7 @@ def get_mid_block_adapter(
     num_attention_heads: Optional[int] = 1,
     cross_attention_dim: Optional[int] = 1024,
     upcast_attention: bool = False,
+    use_linear_projection: bool = True,
 ):
     # Before the midblock application, information is concatted from base to control.
     # Concat doesn't require change in number of channels
@@ -214,7 +216,7 @@ def get_mid_block_adapter(
         resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
         cross_attention_dim=cross_attention_dim,
         num_attention_heads=num_attention_heads,
-        use_linear_projection=True,
+        use_linear_projection=use_linear_projection,
         upcast_attention=upcast_attention,
     )

@@ -308,6 +310,7 @@ def __init__(
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         upcast_attention: bool = True,
         max_norm_num_groups: int = 32,
+        use_linear_projection: bool = True,
     ):
         super().__init__()

@@ -381,6 +384,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )

@@ -393,6 +397,7 @@ def __init__(
             num_attention_heads=num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # up
@@ -489,6 +494,7 @@ def from_unet(
             transformer_layers_per_block=unet.config.transformer_layers_per_block,
             upcast_attention=unet.config.upcast_attention,
             max_norm_num_groups=unet.config.norm_num_groups,
+            use_linear_projection=unet.config.use_linear_projection,
         )

         # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
@@ -538,6 +544,7 @@ def __init__(
         addition_embed_type: Optional[str] = None,
         addition_time_embed_dim: Optional[int] = None,
         upcast_attention: bool = True,
+        use_linear_projection: bool = True,
         time_cond_proj_dim: Optional[int] = None,
         projection_class_embeddings_input_dim: Optional[int] = None,
         # additional controlnet configs
@@ -595,7 +602,12 @@ def __init__(
             time_embed_dim,
             cond_proj_dim=time_cond_proj_dim,
         )
-        self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
+        if ctrl_learn_time_embedding:
+            self.ctrl_time_embedding = TimestepEmbedding(
+                in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
+            )
+        else:
+            self.ctrl_time_embedding = None

         if addition_embed_type is None:
             self.base_add_time_proj = None
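
This hunk is the `unused_parameters` fix from the commit title: when `ctrl_learn_time_embedding` is False, the control time embedding used to exist as a registered submodule that never received gradients, which trips DistributedDataParallel unless `find_unused_parameters=True` is set. Registering it only when it is actually trained avoids that. Below is a minimal, self-contained sketch of the pattern; it is not the diffusers code and the class name is made up.

# Minimal sketch of conditionally registering a submodule so that every
# registered parameter participates in the forward/backward pass.
import torch.nn as nn

class TinyControl(nn.Module):
    def __init__(self, learn_time_embedding: bool):
        super().__init__()
        # Only create the embedding when it will be trained; a dormant
        # submodule would be reported as an unused parameter by DDP.
        self.time_embedding = nn.Linear(4, 4) if learn_time_embedding else None
        self.body = nn.Linear(4, 4)

    def forward(self, x):
        if self.time_embedding is not None:
            x = x + self.time_embedding(x)
        return self.body(x)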
@@ -632,6 +644,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )

@@ -647,6 +660,7 @@ def __init__(
             ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # # Create up blocks
@@ -690,6 +704,7 @@ def __init__(
                     add_upsample=not is_final_block,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
+                    use_linear_projection=use_linear_projection,
                 )
             )

@@ -754,6 +769,7 @@ def from_unet(
             "addition_embed_type",
             "addition_time_embed_dim",
             "upcast_attention",
+            "use_linear_projection",
             "time_cond_proj_dim",
             "projection_class_embeddings_input_dim",
         ]
@@ -1219,6 +1235,7 @@ def __init__(
         cross_attention_dim: Optional[int] = 1024,
         add_downsample: bool = True,
         upcast_attention: Optional[bool] = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         base_resnets = []
@@ -1270,7 +1287,7 @@ def __init__(
                         in_channels=base_out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        use_linear_projection=True,
+                        use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         norm_num_groups=norm_num_groups,
                     )
@@ -1282,7 +1299,7 @@ def __init__(
                         in_channels=ctrl_out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        use_linear_projection=True,
+                        use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
                     )
@@ -1342,13 +1359,15 @@ def get_first_cross_attention(block):
             ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
             cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
             upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
+            use_linear_projection = base_downblock.attentions[0].use_linear_projection
         else:
             has_crossattn = False
             transformer_layers_per_block = None
             base_num_attention_heads = None
             ctrl_num_attention_heads = None
             cross_attention_dim = None
             upcast_attention = None
+            use_linear_projection = None
         add_downsample = base_downblock.downsamplers is not None

         # create model
@@ -1367,6 +1386,7 @@ def get_first_cross_attention(block):
             cross_attention_dim=cross_attention_dim,
             add_downsample=add_downsample,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # # load weights
@@ -1527,6 +1547,7 @@ def __init__(
         ctrl_num_attention_heads: Optional[int] = 1,
         cross_attention_dim: Optional[int] = 1024,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()

@@ -1541,7 +1562,7 @@ def __init__(
             resnet_groups=norm_num_groups,
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=base_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )

@@ -1556,7 +1577,7 @@ def __init__(
             ),
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=ctrl_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )

@@ -1590,6 +1611,7 @@ def get_first_cross_attention(midblock):
         ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
         cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
         upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
+        use_linear_projection = base_midblock.attentions[0].use_linear_projection

         # create model
         model = cls(
@@ -1603,6 +1625,7 @@ def get_first_cross_attention(midblock):
             ctrl_num_attention_heads=ctrl_num_attention_heads,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # load weights
@@ -1677,6 +1700,7 @@ def __init__(
         cross_attention_dim: int = 1024,
         add_upsample: bool = True,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         resnets = []
@@ -1714,7 +1738,7 @@ def __init__(
                         in_channels=out_channels,
                         num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
-                        use_linear_projection=True,
+                        use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
                         norm_num_groups=norm_num_groups,
                     )
@@ -1753,12 +1777,14 @@ def get_first_cross_attention(block):
             num_attention_heads = get_first_cross_attention(base_upblock).heads
             cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
             upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
+            use_linear_projection = base_upblock.attentions[0].use_linear_projection
         else:
             has_crossattn = False
             transformer_layers_per_block = None
             num_attention_heads = None
             cross_attention_dim = None
             upcast_attention = None
+            use_linear_projection = None
         add_upsample = base_upblock.upsamplers is not None

         # create model
@@ -1776,6 +1802,7 @@ def get_first_cross_attention(block):
             cross_attention_dim=cross_attention_dim,
             add_upsample=add_upsample,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # load weights
