|
3 | 3 | from comfy import sd1_clip |
4 | 4 | import torch |
5 | 5 | import math |
| 6 | +import yaml |
6 | 7 | import comfy.utils |
7 | 8 |
|
8 | 9 |
|
@@ -127,26 +128,57 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer): |
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        """Initialize by delegating to the parent class with the ``qwen3_06b`` tokenizer.

        Both arguments are forwarded unchanged to ``super().__init__``, together
        with ``name="qwen3_06b"`` and ``tokenizer=Qwen3Tokenizer``.
        """
        # NOTE(review): `tokenizer_data={}` is a shared mutable default. It is only
        # forwarded here; confirm the parent initializer never mutates it.
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)
129 | 130 |
|
| 131 | + def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str: |
| 132 | + user_metas = { |
| 133 | + k: kwargs.pop(k) |
| 134 | + for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption") |
| 135 | + if k in kwargs |
| 136 | + } |
| 137 | + timesignature = user_metas.get("timesignature") |
| 138 | + if isinstance(timesignature, str) and timesignature.endswith("/4"): |
| 139 | + user_metas["timesignature"] = timesignature.rsplit("/", 1)[0] |
| 140 | + user_metas = { |
| 141 | + k: v if not isinstance(v, str) or not v.isdigit() else int(v) |
| 142 | + for k, v in user_metas.items() |
| 143 | + if v not in {"unspecified", None} |
| 144 | + } |
| 145 | + if len(user_metas): |
| 146 | + meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip() |
| 147 | + else: |
| 148 | + meta_yaml = "" |
| 149 | + return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml |
| 150 | + |
| 151 | + def _metas_to_cap(self, **kwargs) -> str: |
| 152 | + use_keys = ("bpm", "duration", "keyscale", "timesignature") |
| 153 | + user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys } |
| 154 | + duration = user_metas["duration"] |
| 155 | + if duration == "N/A": |
| 156 | + user_metas["duration"] = "30 seconds" |
| 157 | + elif isinstance(duration, (str, int, float)): |
| 158 | + user_metas["duration"] = f"{math.ceil(float(duration))} seconds" |
| 159 | + else: |
| 160 | + raise TypeError("Unexpected type for duration key, must be str, int or float") |
| 161 | + return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys) |
| 162 | + |
130 | 163 | def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): |
131 | 164 | out = {} |
132 | 165 | lyrics = kwargs.get("lyrics", "") |
133 | | - bpm = kwargs.get("bpm", 120) |
134 | 166 | duration = kwargs.get("duration", 120) |
135 | | - keyscale = kwargs.get("keyscale", "C major") |
136 | | - timesignature = kwargs.get("timesignature", 2) |
137 | | - language = kwargs.get("language", "en") |
| 167 | + language = kwargs.get("language") |
138 | 168 | seed = kwargs.get("seed", 0) |
139 | | - |
140 | 169 | duration = math.ceil(duration) |
141 | | - meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature) |
142 | | - lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n<think>\n{}\n</think>\n\n<|im_end|>\n" |
| 170 | + kwargs["duration"] = duration |
| 171 | + |
| 172 | + cot_text = self._metas_to_cot(caption = text, **kwargs) |
| 173 | + meta_cap = self._metas_to_cap(**kwargs) |
| 174 | + |
| 175 | + lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n" |
143 | 176 |
|
144 | | - meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration) |
145 | | - out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True) |
146 | | - out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True) |
| 177 | + out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True) |
| 178 | + out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True) |
147 | 179 |
|
148 | | - out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs) |
149 | | - out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) |
| 180 | + out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs) |
| 181 | + out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) |
150 | 182 | out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed} |
151 | 183 | return out |
152 | 184 |
|
|
0 commit comments