Skip to content

Commit bbbc4c0

Browse files
authored
Merge branch 'main' into speedup-model-loading
2 parents a6ee660 + 941b7fc commit bbbc4c0

File tree

4 files changed

+232
-103
lines changed

4 files changed

+232
-103
lines changed

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def forward(
752752
condition = self.controlnet_cond_embedding(cond)
753753
feat_seq = torch.mean(condition, dim=(2, 3))
754754
feat_seq = feat_seq + self.task_embedding[control_idx]
755-
if from_multi:
755+
if from_multi or len(control_type_idx) == 1:
756756
inputs.append(feat_seq.unsqueeze(1))
757757
condition_list.append(condition)
758758
else:
@@ -772,7 +772,7 @@ def forward(
772772
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
773773
alpha = self.spatial_ch_projs(x[:, idx])
774774
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
775-
if from_multi:
775+
if from_multi or len(control_type_idx) == 1:
776776
controlnet_cond_fuser += condition + alpha
777777
else:
778778
controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ def forward(
819819
# 6. scaling
820820
if guess_mode and not self.config.global_pool_conditions:
821821
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
822-
if from_multi:
822+
if from_multi or len(control_type_idx) == 1:
823823
scales = scales * conditioning_scale[0]
824824
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
825825
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
826-
elif from_multi:
826+
elif from_multi or len(control_type_idx) == 1:
827827
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
828828
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
829829

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,15 @@ def __call__(
187187
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
188188

189189
# 4. Prepare for GQA
190-
query_idx = torch.tensor(query.size(3), device=query.device)
191-
key_idx = torch.tensor(key.size(3), device=key.device)
192-
value_idx = torch.tensor(value.size(3), device=value.device)
190+
if torch.onnx.is_in_onnx_export():
191+
query_idx = torch.tensor(query.size(3), device=query.device)
192+
key_idx = torch.tensor(key.size(3), device=key.device)
193+
value_idx = torch.tensor(value.size(3), device=value.device)
194+
195+
else:
196+
query_idx = query.size(3)
197+
key_idx = key.size(3)
198+
value_idx = value.size(3)
193199
key = key.repeat_interleave(query_idx // key_idx, dim=3)
194200
value = value.repeat_interleave(query_idx // value_idx, dim=3)
195201

0 commit comments

Comments
 (0)