Skip to content

Commit cc33cd3

Browse files
Experimental lyrics strength for ACE. (#7984)
1 parent b998059 commit cc33cd3

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

comfy/ldm/ace/model.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -273,6 +273,7 @@ def encode(
273273
speaker_embeds: Optional[torch.FloatTensor] = None,
274274
lyric_token_idx: Optional[torch.LongTensor] = None,
275275
lyric_mask: Optional[torch.LongTensor] = None,
276+
lyrics_strength=1.0,
276277
):
277278

278279
bs = encoder_text_hidden_states.shape[0]
@@ -291,6 +292,8 @@ def encode(
291292
out_dtype=encoder_text_hidden_states.dtype,
292293
)
293294

295+
encoder_lyric_hidden_states *= lyrics_strength
296+
294297
encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
295298

296299
encoder_hidden_mask = None
@@ -310,7 +313,6 @@ def decode(
310313
output_length: int = 0,
311314
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
312315
controlnet_scale: Union[float, torch.Tensor] = 1.0,
313-
return_dict: bool = True,
314316
):
315317
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
316318
temb = self.t_block(embedded_timestep)
@@ -353,6 +355,7 @@ def forward(
353355
lyric_mask: Optional[torch.LongTensor] = None,
354356
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
355357
controlnet_scale: Union[float, torch.Tensor] = 1.0,
358+
lyrics_strength=1.0,
356359
**kwargs
357360
):
358361
hidden_states = x
@@ -363,6 +366,7 @@ def forward(
363366
speaker_embeds=speaker_embeds,
364367
lyric_token_idx=lyric_token_idx,
365368
lyric_mask=lyric_mask,
369+
lyrics_strength=lyrics_strength,
366370
)
367371

368372
output_length = hidden_states.shape[-1]

comfy/model_base.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1139,4 +1139,5 @@ def extra_conds(self, **kwargs):
11391139
if cross_attn is not None:
11401140
out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics)
11411141
out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
1142+
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
11421143
return out

comfy_extras/nodes_ace.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import torch
22
import comfy.model_management
3-
3+
import node_helpers
44

55
class TextEncodeAceStepAudio:
66
@classmethod
@@ -9,15 +9,18 @@ def INPUT_TYPES(s):
99
"clip": ("CLIP", ),
1010
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
1111
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
12+
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
1213
}}
1314
RETURN_TYPES = ("CONDITIONING",)
1415
FUNCTION = "encode"
1516

1617
CATEGORY = "conditioning"
1718

18-
def encode(self, clip, tags, lyrics):
19+
def encode(self, clip, tags, lyrics, lyrics_strength):
1920
tokens = clip.tokenize(tags, lyrics=lyrics)
20-
return (clip.encode_from_tokens_scheduled(tokens), )
21+
conditioning = clip.encode_from_tokens_scheduled(tokens)
22+
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
23+
return (conditioning, )
2124

2225

2326
class EmptyAceStepLatentAudio:

0 commit comments

Comments
 (0)