Skip to content

Commit a510a56

Browse files
Merge branch 'master' into improve_ace15_te
2 parents 32dd91c + a50c32d commit a510a56

File tree

9 files changed

+68
-26
lines changed

9 files changed

+68
-26
lines changed

comfy/ldm/ace/ace_step15.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def forward(
183183
else:
184184
attn_bias = window_bias
185185

186-
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
186+
attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False)
187187
attn_output = self.o_proj(attn_output)
188188

189189
return attn_output
@@ -1035,8 +1035,7 @@ def prepare_condition(
10351035
audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
10361036
lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
10371037
else:
1038-
assert False
1039-
# TODO ?
1038+
lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed)
10401039

10411040
lm_hints = self.detokenizer(lm_hints_5Hz)
10421041

comfy/ldm/modules/attention.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
524524

525525
@wrap_attn
526526
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
527+
if kwargs.get("low_precision_attention", True) is False:
528+
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
529+
527530
exception_fallback = False
528531
if skip_reshape:
529532
b, _, _, dim_head = q.shape

comfy/model_base.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,6 +1548,7 @@ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
15481548
def extra_conds(self, **kwargs):
15491549
out = super().extra_conds(**kwargs)
15501550
device = kwargs["device"]
1551+
noise = kwargs["noise"]
15511552

15521553
cross_attn = kwargs.get("cross_attn", None)
15531554
if cross_attn is not None:
@@ -1571,15 +1572,19 @@ def extra_conds(self, **kwargs):
15711572
1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
15721573
-8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
15731574
-5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
1574-
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
1575+
7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, noise.shape[2])
1576+
pass_audio_codes = True
15751577
else:
1576-
refer_audio = refer_audio[-1]
1577-
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
1578+
refer_audio = refer_audio[-1][:, :, :noise.shape[2]]
1579+
pass_audio_codes = False
15781580

1579-
audio_codes = kwargs.get("audio_codes", None)
1580-
if audio_codes is not None:
1581-
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
1581+
if pass_audio_codes:
1582+
audio_codes = kwargs.get("audio_codes", None)
1583+
if audio_codes is not None:
1584+
out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
1585+
refer_audio = refer_audio[:, :, :750]
15821586

1587+
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
15831588
return out
15841589

15851590
class Omnigen2(BaseModel):

comfy/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
5454
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
5555

5656
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
57+
if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
58+
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
5759
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
5860
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
5961
else:

comfy/text_encoders/ace15.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,7 @@ def sample_manual_loop_no_classes(
102102
return output_audio_codes
103103

104104

105-
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
106-
cfg_scale = 2.0
107-
105+
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
108106
positive = [[token for token, _ in inner_list] for inner_list in positive]
109107
negative = [[token for token, _ in inner_list] for inner_list in negative]
110108
positive = positive[0]
@@ -121,7 +119,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
121119
positive = [model.special_tokens["pad"]] * pos_pad + positive
122120

123121
paddings = [pos_pad, neg_pad]
124-
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
122+
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
125123

126124

127125
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -166,6 +164,14 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
166164
duration = kwargs.get("duration", 120)
167165
language = kwargs.get("language")
168166
seed = kwargs.get("seed", 0)
167+
168+
generate_audio_codes = kwargs.get("generate_audio_codes", True)
169+
cfg_scale = kwargs.get("cfg_scale", 2.0)
170+
temperature = kwargs.get("temperature", 0.85)
171+
top_p = kwargs.get("top_p", 0.9)
172+
top_k = kwargs.get("top_k", 0.0)
173+
174+
169175
duration = math.ceil(duration)
170176
kwargs["duration"] = duration
171177

@@ -179,7 +185,14 @@ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
179185

180186
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
181187
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
182-
out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
188+
out["lm_metadata"] = {"min_tokens": duration * 5,
189+
"seed": seed,
190+
"generate_audio_codes": generate_audio_codes,
191+
"cfg_scale": cfg_scale,
192+
"temperature": temperature,
193+
"top_p": top_p,
194+
"top_k": top_k,
195+
}
183196
return out
184197

185198

@@ -235,10 +248,14 @@ def encode_token_weights(self, token_weight_pairs):
235248
self.qwen3_06b.set_clip_options({"layer": [0]})
236249
lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
237250

251+
out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
252+
238253
lm_metadata = token_weight_pairs["lm_metadata"]
239-
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
254+
if lm_metadata["generate_audio_codes"]:
255+
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
256+
out["audio_codes"] = [audio_codes]
240257

241-
return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
258+
return base_out, None, out
242259

243260
def set_clip_options(self, options):
244261
self.qwen3_06b.set_clip_options(options)

comfy/text_encoders/llama.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -651,10 +651,10 @@ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermed
651651
mask = None
652652
if attention_mask is not None:
653653
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
654-
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min)
654+
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4)
655655

656656
if seq_len > 1:
657-
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1)
657+
causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1)
658658
if mask is not None:
659659
mask += causal_mask
660660
else:

comfy_extras/nodes_ace.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,18 @@ def define_schema(cls):
4444
io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
4545
io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
4646
io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
47+
io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
48+
io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
49+
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
50+
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
51+
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
4752
],
4853
outputs=[io.Conditioning.Output()],
4954
)
5055

5156
@classmethod
52-
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
53-
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
57+
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
58+
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
5459
conditioning = clip.encode_from_tokens_scheduled(tokens)
5560
return io.NodeOutput(conditioning)
5661

@@ -100,14 +105,15 @@ def execute(cls, seconds, batch_size) -> io.NodeOutput:
100105
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
101106
return io.NodeOutput({"samples": latent, "type": "audio"})
102107

103-
class ReferenceTimbreAudio(io.ComfyNode):
108+
class ReferenceAudio(io.ComfyNode):
104109
@classmethod
105110
def define_schema(cls):
106111
return io.Schema(
107112
node_id="ReferenceTimbreAudio",
113+
display_name="Reference Audio",
108114
category="advanced/conditioning/audio",
109115
is_experimental=True,
110-
description="This node sets the reference audio for timbre (for ace step 1.5)",
116+
description="This node sets the reference audio for ace step 1.5",
111117
inputs=[
112118
io.Conditioning.Input("conditioning"),
113119
io.Latent.Input("latent", optional=True),
@@ -131,7 +137,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
131137
EmptyAceStepLatentAudio,
132138
TextEncodeAceStepAudio15,
133139
EmptyAceStep15LatentAudio,
134-
ReferenceTimbreAudio,
140+
ReferenceAudio,
135141
]
136142

137143
async def comfy_entrypoint() -> AceExtension:

comfy_extras/nodes_hunyuan3d.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,7 @@ class SaveGLB(IO.ComfyNode):
618618
def define_schema(cls):
619619
return IO.Schema(
620620
node_id="SaveGLB",
621+
display_name="Save 3D Model",
621622
search_aliases=["export 3d model", "save mesh"],
622623
category="3d",
623624
is_output_node=True,
@@ -626,8 +627,14 @@ def define_schema(cls):
626627
IO.Mesh.Input("mesh"),
627628
types=[
628629
IO.File3DGLB,
630+
IO.File3DGLTF,
631+
IO.File3DOBJ,
632+
IO.File3DFBX,
633+
IO.File3DSTL,
634+
IO.File3DUSDZ,
635+
IO.File3DAny,
629636
],
630-
tooltip="Mesh or GLB file to save",
637+
tooltip="Mesh or 3D file to save",
631638
),
632639
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
633640
],
@@ -649,7 +656,8 @@ def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.No
649656

650657
if isinstance(mesh, Types.File3D):
651658
# Handle File3D input - save BytesIO data to output folder
652-
f = f"{filename}_{counter:05}_.glb"
659+
ext = mesh.format or "glb"
660+
f = f"{filename}_{counter:05}_.{ext}"
653661
mesh.save_to(os.path.join(full_output_folder, f))
654662
results.append({
655663
"filename": f,

comfy_extras/nodes_load_3d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def define_schema(cls):
4545
IO.Image.Output(display_name="normal"),
4646
IO.Load3DCamera.Output(display_name="camera_info"),
4747
IO.Video.Output(display_name="recording_video"),
48+
IO.File3DAny.Output(display_name="model_3d"),
4849
],
4950
)
5051

@@ -66,7 +67,8 @@ def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
6667

6768
video = InputImpl.VideoFromFile(recording_video_path)
6869

69-
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
70+
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
71+
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
7072

7173
process = execute # TODO: remove
7274

0 commit comments

Comments
 (0)