Commit 32dd91c

Improvements to ACE-Steps 1.5 text encoding
1 parent ef73070 commit 32dd91c

1 file changed

comfy/text_encoders/ace15.py

Lines changed: 44 additions & 12 deletions
@@ -3,6 +3,7 @@
 from comfy import sd1_clip
 import torch
 import math
+import yaml
 import comfy.utils


@@ -127,26 +128,57 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)

+    def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
+        user_metas = {
+            k: kwargs.pop(k)
+            for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
+            if k in kwargs
+        }
+        timesignature = user_metas.get("timesignature")
+        if isinstance(timesignature, str) and timesignature.endswith("/4"):
+            user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
+        user_metas = {
+            k: v if not isinstance(v, str) or not v.isdigit() else int(v)
+            for k, v in user_metas.items()
+            if v not in {"unspecified", None}
+        }
+        if len(user_metas):
+            meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip()
+        else:
+            meta_yaml = ""
+        return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
+
+    def _metas_to_cap(self, **kwargs) -> str:
+        use_keys = ("bpm", "duration", "keyscale", "timesignature")
+        user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
+        duration = user_metas["duration"]
+        if duration == "N/A":
+            user_metas["duration"] = "30 seconds"
+        elif isinstance(duration, (str, int, float)):
+            user_metas["duration"] = f"{math.ceil(float(duration))} seconds"
+        else:
+            raise TypeError("Unexpected type for duration key, must be str, int or float")
+        return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
+
     def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
         out = {}
         lyrics = kwargs.get("lyrics", "")
-        bpm = kwargs.get("bpm", 120)
         duration = kwargs.get("duration", 120)
-        keyscale = kwargs.get("keyscale", "C major")
-        timesignature = kwargs.get("timesignature", 2)
-        language = kwargs.get("language", "en")
+        language = kwargs.get("language")
         seed = kwargs.get("seed", 0)
-
         duration = math.ceil(duration)
-        meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
-        lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n<think>\n{}\n</think>\n\n<|im_end|>\n"
+        kwargs["duration"] = duration
+
+        cot_text = self._metas_to_cot(caption = text, **kwargs)
+        meta_cap = self._metas_to_cap(**kwargs)
+
+        lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"

-        meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration)
-        out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True)
-        out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True)
+        out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
+        out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)

-        out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
-        out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
+        out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
+        out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
         out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
         return out

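For quick reference, the following is a small self-contained sketch, not code from this commit, that mirrors what the new _metas_to_cot and _metas_to_cap helpers produce for a hypothetical request (the metadata values below are made up), so the resulting <think> block and "# Metas" list can be inspected without loading the tokenizer.

# Illustrative sketch only; assumes made-up metadata and re-derives the
# transformations shown in the diff above.
import math
import yaml

metas = {
    "caption": "upbeat synth pop with bright pads",  # hypothetical caption text
    "bpm": 120,
    "duration": 95,
    "keyscale": "C major",
    "timesignature": "4/4",  # the "/4" suffix is stripped, leaving "4"
    "language": "en",
}

# <think> block: digit strings become ints, "unspecified"/None entries are
# dropped, and the remainder is dumped as alphabetically sorted YAML.
cot = dict(metas)
if isinstance(cot.get("timesignature"), str) and cot["timesignature"].endswith("/4"):
    cot["timesignature"] = cot["timesignature"].rsplit("/", 1)[0]
cot = {k: int(v) if isinstance(v, str) and v.isdigit() else v
       for k, v in cot.items() if v not in {"unspecified", None}}
print("<think>\n" + yaml.dump(cot, allow_unicode=True, sort_keys=True).strip() + "\n</think>")

# "# Metas" block: duration is rounded up and rendered in seconds; a missing
# key would fall back to "N/A" ("30 seconds" for a missing duration).
cap_keys = ("bpm", "duration", "keyscale", "timesignature")
cap = {k: metas.get(k, "N/A") for k in cap_keys}
cap["duration"] = f"{math.ceil(float(cap['duration']))} seconds"
print("\n".join(f"- {k}: {cap[k]}" for k in cap_keys))

In the commit itself, the <think> output fills the assistant turn of the positive lm_prompt (the negative prompt gets an empty "<think>\n</think>"), the bullet list fills the "# Metas" section of the qwen3_06b prompt, and the rounded duration still drives lm_metadata["min_tokens"] = duration * 5.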