@@ -114,6 +114,7 @@ def get_down_block_adapter(
     cross_attention_dim: Optional[int] = 1024,
     add_downsample: bool = True,
     upcast_attention: Optional[bool] = False,
+    use_linear_projection: Optional[bool] = True,
 ):
     num_layers = 2  # only support sd + sdxl

@@ -152,7 +153,7 @@ def get_down_block_adapter(
                     in_channels=ctrl_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
                 )
@@ -200,6 +201,7 @@ def get_mid_block_adapter(
     num_attention_heads: Optional[int] = 1,
     cross_attention_dim: Optional[int] = 1024,
     upcast_attention: bool = False,
+    use_linear_projection: bool = True,
 ):
     # Before the midblock application, information is concatted from base to control.
     # Concat doesn't require change in number of channels
@@ -214,7 +216,7 @@ def get_mid_block_adapter(
         resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
         cross_attention_dim=cross_attention_dim,
         num_attention_heads=num_attention_heads,
-        use_linear_projection=True,
+        use_linear_projection=use_linear_projection,
         upcast_attention=upcast_attention,
     )

@@ -308,6 +310,7 @@ def __init__(
         transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         upcast_attention: bool = True,
         max_norm_num_groups: int = 32,
+        use_linear_projection: bool = True,
     ):
         super().__init__()

@@ -381,6 +384,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )

@@ -393,6 +397,7 @@ def __init__(
             num_attention_heads=num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # up
@@ -489,6 +494,7 @@ def from_unet(
             transformer_layers_per_block=unet.config.transformer_layers_per_block,
             upcast_attention=unet.config.upcast_attention,
             max_norm_num_groups=unet.config.norm_num_groups,
+            use_linear_projection=unet.config.use_linear_projection,
         )

         # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
@@ -538,6 +544,7 @@ def __init__(
         addition_embed_type: Optional[str] = None,
         addition_time_embed_dim: Optional[int] = None,
         upcast_attention: bool = True,
+        use_linear_projection: bool = True,
         time_cond_proj_dim: Optional[int] = None,
         projection_class_embeddings_input_dim: Optional[int] = None,
         # additional controlnet configs
@@ -595,7 +602,12 @@ def __init__(
             time_embed_dim,
             cond_proj_dim=time_cond_proj_dim,
         )
-        self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
+        if ctrl_learn_time_embedding:
+            self.ctrl_time_embedding = TimestepEmbedding(
+                in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
+            )
+        else:
+            self.ctrl_time_embedding = None

         if addition_embed_type is None:
             self.base_add_time_proj = None
@@ -632,6 +644,7 @@ def __init__(
                     cross_attention_dim=cross_attention_dim[i],
                     add_downsample=not is_final_block,
                     upcast_attention=upcast_attention,
+                    use_linear_projection=use_linear_projection,
                 )
             )

@@ -647,6 +660,7 @@ def __init__(
             ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
             cross_attention_dim=cross_attention_dim[-1],
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # # Create up blocks
@@ -690,6 +704,7 @@ def __init__(
                     add_upsample=not is_final_block,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
+                    use_linear_projection=use_linear_projection,
                 )
             )

@@ -754,6 +769,7 @@ def from_unet(
             "addition_embed_type",
             "addition_time_embed_dim",
             "upcast_attention",
+            "use_linear_projection",
             "time_cond_proj_dim",
             "projection_class_embeddings_input_dim",
         ]
@@ -1219,6 +1235,7 @@ def __init__(
         cross_attention_dim: Optional[int] = 1024,
         add_downsample: bool = True,
         upcast_attention: Optional[bool] = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         base_resnets = []
@@ -1270,7 +1287,7 @@ def __init__(
                     in_channels=base_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
                 )
@@ -1282,7 +1299,7 @@ def __init__(
                     in_channels=ctrl_out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
                 )
@@ -1342,13 +1359,15 @@ def get_first_cross_attention(block):
             ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
             cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
             upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
+            use_linear_projection = base_downblock.attentions[0].use_linear_projection
         else:
             has_crossattn = False
             transformer_layers_per_block = None
             base_num_attention_heads = None
             ctrl_num_attention_heads = None
             cross_attention_dim = None
             upcast_attention = None
+            use_linear_projection = None
         add_downsample = base_downblock.downsamplers is not None

         # create model
@@ -1367,6 +1386,7 @@ def get_first_cross_attention(block):
             cross_attention_dim=cross_attention_dim,
             add_downsample=add_downsample,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # # load weights
@@ -1527,6 +1547,7 @@ def __init__(
         ctrl_num_attention_heads: Optional[int] = 1,
         cross_attention_dim: Optional[int] = 1024,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()

@@ -1541,7 +1562,7 @@ def __init__(
             resnet_groups=norm_num_groups,
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=base_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )

@@ -1556,7 +1577,7 @@ def __init__(
             ),
             cross_attention_dim=cross_attention_dim,
             num_attention_heads=ctrl_num_attention_heads,
-            use_linear_projection=True,
+            use_linear_projection=use_linear_projection,
             upcast_attention=upcast_attention,
         )

@@ -1590,6 +1611,7 @@ def get_first_cross_attention(midblock):
         ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
         cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
         upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
+        use_linear_projection = base_midblock.attentions[0].use_linear_projection

         # create model
         model = cls(
@@ -1603,6 +1625,7 @@ def get_first_cross_attention(midblock):
             ctrl_num_attention_heads=ctrl_num_attention_heads,
             cross_attention_dim=cross_attention_dim,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # load weights
@@ -1677,6 +1700,7 @@ def __init__(
         cross_attention_dim: int = 1024,
         add_upsample: bool = True,
         upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
     ):
         super().__init__()
         resnets = []
@@ -1714,7 +1738,7 @@ def __init__(
                     in_channels=out_channels,
                     num_layers=transformer_layers_per_block[i],
                     cross_attention_dim=cross_attention_dim,
-                    use_linear_projection=True,
+                    use_linear_projection=use_linear_projection,
                     upcast_attention=upcast_attention,
                     norm_num_groups=norm_num_groups,
                 )
@@ -1753,12 +1777,14 @@ def get_first_cross_attention(block):
             num_attention_heads = get_first_cross_attention(base_upblock).heads
             cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
             upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
+            use_linear_projection = base_upblock.attentions[0].use_linear_projection
         else:
             has_crossattn = False
             transformer_layers_per_block = None
             num_attention_heads = None
             cross_attention_dim = None
             upcast_attention = None
+            use_linear_projection = None
         add_upsample = base_upblock.upsamplers is not None

         # create model
@@ -1776,6 +1802,7 @@ def get_first_cross_attention(block):
             cross_attention_dim=cross_attention_dim,
             add_upsample=add_upsample,
             upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
         )

         # load weights