Skip to content

Commit ae44801

Browse files
committed
Best quality generation and more options
1 parent 17aaea1 commit ae44801

10 files changed

+1514
-280
lines changed

__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55
from .ace_step_prompt_gen import NODE_CLASS_MAPPINGS as PROMPT_MAPPINGS, NODE_DISPLAY_NAMES as PROMPT_NAMES
66
from .lyrics_nodes import NODE_CLASS_MAPPINGS as LYRICS_MAPPINGS, NODE_DISPLAY_NAMES as LYRICS_NAMES
77
from .ace_step_save_text import NODE_CLASS_MAPPINGS as SAVETEXT_MAPPINGS, NODE_DISPLAY_NAMES as SAVETEXT_NAMES
8+
from .ace_step_post_process import NODE_CLASS_MAPPINGS as POSTPROCESS_MAPPINGS, NODE_DISPLAY_NAMES as POSTPROCESS_NAMES
9+
from .ace_step_vocoder_adapter import NODE_CLASS_MAPPINGS as VOCODER_ADAPTER_MAPPINGS, NODE_DISPLAY_NAMES as VOCODER_ADAPTER_NAMES
810
# DISABLED: optimization_nodes removed (torch.compile incompatibility)
911
# from .optimization_nodes import NODE_CLASS_MAPPINGS as OPT_MAPPINGS, NODE_DISPLAY_NAMES as OPT_NAMES
1012
# DISABLED: torch_compile_node causes incompatibility with ACE-Step
1113
# from .torch_compile_node import NODE_CLASS_MAPPINGS as COMPILE_MAPPINGS, NODE_DISPLAY_NAMES as COMPILE_NAMES
1214

1315
# Combine all node mappings
14-
NODE_CLASS_MAPPINGS = {**KSAMPLER_MAPPINGS, **PROMPT_MAPPINGS, **LYRICS_MAPPINGS, **SAVETEXT_MAPPINGS}
15-
NODE_DISPLAY_NAMES = {**KSAMPLER_NAMES, **PROMPT_NAMES, **LYRICS_NAMES, **SAVETEXT_NAMES}
16+
NODE_CLASS_MAPPINGS = {**KSAMPLER_MAPPINGS, **PROMPT_MAPPINGS, **LYRICS_MAPPINGS, **SAVETEXT_MAPPINGS, **POSTPROCESS_MAPPINGS, **VOCODER_ADAPTER_MAPPINGS}
17+
NODE_DISPLAY_NAMES = {**KSAMPLER_NAMES, **PROMPT_NAMES, **LYRICS_NAMES, **SAVETEXT_NAMES, **POSTPROCESS_NAMES, **VOCODER_ADAPTER_NAMES}
1618

1719
# Register custom samplers with ComfyUI
1820
def add_samplers():

ace_step_ksampler.py

Lines changed: 223 additions & 7 deletions
Large diffs are not rendered by default.

ace_step_post_process.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import torch
2+
import logging
3+
4+
logger = logging.getLogger(__name__)
5+
6+
7+
class AceStepPostProcess:
    """Post-process node that tames metallic sibilance (de-esser), optionally
    smooths the spectrum, and can overlay a soft breath track.

    Audio follows the ComfyUI AUDIO convention: a dict with a ``waveform``
    tensor shaped [B, C, T] and a ``sample_rate`` int (44100 assumed when
    missing).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "audio": ("AUDIO",),
            },
            "optional": {
                "de_esser_strength": ("FLOAT", {"default": 0.12, "min": 0.0, "max": 0.6, "step": 0.01}),
                "spectral_smoothing": ("FLOAT", {"default": 0.08, "min": 0.0, "max": 0.5, "step": 0.01}),
                "breath_mix": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 0.2, "step": 0.01}),
                "breath_audio": ("AUDIO",),
            },
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "process"
    CATEGORY = "JK AceStep Nodes/PostProcess"

    def process(self, audio, de_esser_strength=0.12, spectral_smoothing=0.08, breath_mix=0.0, breath_audio=None):
        """De-ess, smooth and optionally add breath to `audio`.

        Args:
            audio: AUDIO dict (or bare tensor) to process.
            de_esser_strength: fraction of energy removed above 6 kHz.
            spectral_smoothing: blend factor of a 3-tap smoothing over frequency.
            breath_mix: gain of the optional breath overlay (first channel only).
            breath_audio: optional AUDIO dict providing the breath waveform.

        Returns:
            1-tuple with the processed AUDIO dict. On any failure the input is
            returned unchanged (best-effort node).
        """
        try:
            waveform = audio["waveform"] if isinstance(audio, dict) and "waveform" in audio else audio
            if not isinstance(waveform, torch.Tensor):
                logger.warning("Input audio is not a torch.Tensor, skipping post-processing.")
                return (audio,)

            x = waveform
            # Expect [B, C, T]; promote [B, T] to a single channel.
            if x.dim() == 2:
                x = x.unsqueeze(1)

            B, C, T = x.shape
            # Short-time Fourier Transform parameters.
            n_fft = 2048
            hop_length = 512
            win = torch.hann_window(n_fft).to(device=x.device, dtype=x.dtype)
            sr = audio.get("sample_rate", 44100) if isinstance(audio, dict) else 44100
            # Loop-invariant: frequency of each rfft bin ([n_fft//2 + 1] values).
            freqs = torch.fft.rfftfreq(n_fft, 1.0 / sr).to(x.device)
            # BUGFIX: the de-esser mask must broadcast over time frames, i.e.
            # be shaped [F, 1] against mag's [F, T]. The original [1, F] shape
            # raised a broadcast error that the broad except silently swallowed,
            # making the whole node a no-op.
            de_ess_mask = (freqs > 6000.0).to(x.dtype).unsqueeze(1)

            out = x.clone()
            for b in range(B):
                for c in range(C):
                    sig = x[b, c]
                    stft = torch.stft(sig, n_fft=n_fft, hop_length=hop_length, win_length=n_fft, window=win, return_complex=True)
                    mag = torch.abs(stft)            # [F, T_frames]
                    phase = torch.angle(stft)
                    # De-esser: attenuate energy above 6 kHz proportionally.
                    mag = mag * (1.0 - de_esser_strength * de_ess_mask)
                    if spectral_smoothing > 0.0:
                        # Smooth across the FREQUENCY axis with a [.25, .5, .25]
                        # kernel. conv1d wants [N, C, L]; treat each time frame
                        # as a batch item whose sequence runs over freq bins.
                        # BUGFIX: the original fed a 2-D [F, T] tensor straight
                        # to conv1d (in_channels mismatch) and padded the time
                        # axis, contradicting its own comment.
                        kernel = torch.tensor([0.25, 0.5, 0.25], dtype=mag.dtype, device=mag.device).view(1, 1, -1)
                        frames = mag.transpose(0, 1).unsqueeze(1)                  # [T, 1, F]
                        padded = torch.nn.functional.pad(frames, (1, 1), mode="reflect")
                        smoothed = torch.nn.functional.conv1d(padded, kernel).squeeze(1).transpose(0, 1)  # [F, T]
                        mag = (1.0 - spectral_smoothing) * mag + spectral_smoothing * smoothed
                    complex_spec = torch.polar(mag, phase)
                    out[b, c] = torch.istft(complex_spec, n_fft=n_fft, hop_length=hop_length, win_length=n_fft, window=win, length=T)

            # Optional breath overlay, first channel only, clipped to the
            # shorter of the two signals.
            if breath_mix > 0.0 and breath_audio is not None and isinstance(breath_audio, dict):
                breath_wave = breath_audio.get("waveform", None)
                if isinstance(breath_wave, torch.Tensor):
                    n = min(out.shape[2], breath_wave.shape[-1])
                    out[:, 0, :n] += breath_mix * breath_wave[:, 0, :n]

            # Peak-normalize; clamp guards against division by ~0 on silence.
            out = out / out.abs().max().clamp(min=1e-5)
            if isinstance(audio, dict):
                audio["waveform"] = out
                return (audio,)
            # BUGFIX: in the bare-tensor path `audio` has no .get(); the
            # original raised AttributeError here. Use the default rate.
            return ({"waveform": out, "sample_rate": 44100},)
        except Exception as e:
            # Best-effort node: never break the graph, return input unchanged.
            logger.error(f"Post processing failed: {e}")
            return (audio,)
85+
86+
87+
# ComfyUI registration tables for this module.
NODE_CLASS_MAPPINGS = {"AceStepPostProcess": AceStepPostProcess}

NODE_DISPLAY_NAMES = {"AceStepPostProcess": "Ace-Step Post Process"}

ace_step_prompt_gen.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,19 +257,90 @@ def INPUT_TYPES(cls):
257257
},
258258
),
259259
},
260+
"optional": {
261+
"voice_style": (
262+
[
263+
"none",
264+
"natural_female",
265+
"breathy_female",
266+
"powerful_female",
267+
"ethereal_female",
268+
"soulful_female",
269+
"deep_female",
270+
"natural_male",
271+
"breathy_male",
272+
"powerful_male",
273+
"deep_male",
274+
"soulful_male",
275+
"tenor_male",
276+
"baritone_male",
277+
"reference_singer",
278+
"androgynous",
279+
"vocal_blend",
280+
"robotic_vocal"
281+
],
282+
{
283+
"default": "none",
284+
"tooltip": "Optional voice style hints that are appended to the prompt to improve vocal realism. Female (6 options), Male (6 options), Blended (2 options), Robotic (1), None (auto)."
285+
}
286+
),
287+
},
260288
}
261289

262290
RETURN_TYPES = ("STRING", "STRING")
263291
RETURN_NAMES = ("prompt", "template")
264292
FUNCTION = "generate"
265293
CATEGORY = "JK AceStep Nodes/Prompt"
266294

267-
def generate(self, style: str, extra: str = ""):
295+
def generate(self, style: str, extra: str = "", voice_style: str = "none"):
    """Build the final prompt from a style template, optional extra text and
    an optional voice-style hint.

    Args:
        style: key into the module-level STYLE_PROMPTS table.
        extra: free-form text appended after the template.
        voice_style: one of the voice options offered in INPUT_TYPES.

    Returns:
        (prompt, template) tuple; ``template`` is the raw style template so
        callers can inspect it separately.
    """
    template = STYLE_PROMPTS.get(style, "")

    # Voice-style hints, table-driven instead of a long elif chain.
    # BUGFIX: "baritone_male" is offered in INPUT_TYPES but previously had no
    # branch here, so selecting it silently behaved like "none".
    voice_hints = {
        # FEMALE VOCALS
        "natural_female": "natural female voice with micro pitch variation, soft breath, realistic vibrato and avoid robotic quantization",
        "breathy_female": "breathy female voice, intimate mic proximity, audible breaths, warm vowel resonances, minimal autotune",
        "powerful_female": "powerful female lead vocal, energetic performance, controlled vibrato, clear consonant articulation, live vocal tone",
        "ethereal_female": "ethereal female voice, airy and light, floating above the beat, delicate phrasing, spacious reverb, dreamy quality",
        "soulful_female": "soulful female voice, rich emotional depth, warm tone, blues influences, expressive phrasing, powerful presence",
        "deep_female": "deep female mezzo-soprano voice, lower register, sultry tone, sophisticated delivery, jazz influences",
        # MALE VOCALS
        "natural_male": "natural male voice with micro pitch variation and natural prosody, warm tone, realistic breath",
        "breathy_male": "breathy male voice, intimate vocal delivery, audible breath texture, vulnerable performance, close-mic warmth",
        "powerful_male": "powerful male lead vocal, strong projection, controlled vibrato, clear articulation, commanding presence",
        "deep_male": "deep male baritone-bass voice, rich resonance, lower register dominance, warm sonority, authoritative tone",
        "soulful_male": "soulful male voice, emotional depth, R&B influences, smooth phrasing, expressive delivery, warm presence",
        "tenor_male": "bright tenor male voice, soaring high notes, clear articulation, pop sensibility, energetic performance",
        "baritone_male": "warm baritone male voice, rich mid-low register, smooth sustained phrasing, confident relaxed delivery",
        # BLENDED / OTHER VOCALS
        "reference_singer": "use a reference lead singer performance: natural, human vocal delivery, no robotic artifacts",
        "androgynous": "androgynous voice quality, neutral gender presentation, balanced tone between male and female characteristics",
        "vocal_blend": "layered vocal blend with multiple voices, rich harmonic texture, complementary vocal ranges, ensemble quality",
        # ROBOTIC VOCAL
        "robotic_vocal": "robotic vocal style, vocoder or autotune effect, synthetic timbre, precise pitch, electronic articulation, minimal human expressiveness, classic EDM/Daft Punk/house vocal texture",
    }
    voice_hint = voice_hints.get(voice_style, "")

    # Assemble: template first, then extra text, then the voice hint.
    parts = [template] if template else []
    if extra.strip():
        parts.append(extra.strip())
    if voice_hint:
        parts.append(voice_hint)
    final_prompt = ". ".join(parts)
    return (final_prompt, template)
274345

275346

ace_step_vocoder_adapter.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""Vocoder Adapter Node
2+
3+
This node is a minimal adapter to help feed Ace-Step latents or decoded audio into a vocoder.
4+
It tries to detect the expected input type of the provided vocoder object and calls the right API.
5+
Supported flows:
6+
- Latent -> VAE -> waveform -> mel -> vocoder
7+
- Latent -> mel (if latent appears to be mel) -> vocoder
8+
- Clean waveform -> vocoder (if vocoder expects waveform for final polish)
9+
10+
Notes:
11+
- Vocoder objects must be Python objects exposed to ComfyUI nodes (i.e., selected via a model node).
12+
- The node supports `mel_transform` using `librosa` if present, otherwise uses Torch-based mel filter.
13+
"""
14+
import logging
15+
import torch
16+
import numpy as np
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
class AceStepVocoderAdapter:
    """Adapter that feeds Ace-Step latents (or VAE-decoded audio) into a vocoder.

    Flow:
      1. If the latent already looks like a mel spectrogram, pass it
         (log-scaled) to the vocoder directly.
      2. Otherwise decode the latent with the VAE and convert the waveform to
         a log-mel spectrogram.
      3. Call the vocoder via whichever API it exposes
         (infer / synthesize / decode / plain callable).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "vocoder": ("MODEL",),
                "vae": ("VAE",),
                "latent": ("LATENT",),
            },
            "optional": {
                "sample_rate": ("INT", {"default": 44100}),
                "n_mels": ("INT", {"default": 128}),
                "n_fft": ("INT", {"default": 2048}),
                "hop_length": ("INT", {"default": 512}),
            }
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "adapt"
    CATEGORY = "JK AceStep Nodes/Vocoder"

    def _to_mel_torch(self, waveform, sr=44100, n_fft=2048, hop=512, n_mels=128):
        """Convert a waveform ([B, C, T], [B, T] or [T]) to a log-mel spectrogram.

        Uses torchaudio when available; otherwise falls back to an STFT
        magnitude whose frequency axis is linearly resampled to n_mels bins
        (an approximation, not a true mel scale).
        """
        import torch.nn.functional as F
        # Reduce to a mono batch [B, T] (first channel only).
        if waveform.dim() == 3:
            wav = waveform[:, 0]
        elif waveform.dim() == 2:
            # BUGFIX: the original `waveform[:, 0]` took a single SAMPLE per
            # row (a 1-D vector), not a mono batch. A dim-2 tensor here is
            # already [B, T] mono, matching this package's audio convention.
            wav = waveform
        else:
            wav = waveform.unsqueeze(0)
        try:
            import torchaudio
            mel_spec = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=n_fft, hop_length=hop, n_mels=n_mels)(wav)
            return torch.log(torch.clamp(mel_spec, 1e-9))
        except Exception:
            # Lowest-effort fallback: STFT magnitude, then resample frequency.
            stft = torch.stft(wav, n_fft=n_fft, hop_length=hop, return_complex=True)
            mag = torch.abs(stft)                      # [B, F, T]
            # BUGFIX: interpolate over the FREQUENCY axis. The original passed
            # a 4-D tensor to mode='linear' (which needs 3-D input) and would
            # have resized the time axis.
            mel = F.interpolate(mag.permute(0, 2, 1), size=n_mels, mode='linear', align_corners=False).permute(0, 2, 1)
            return torch.log(torch.clamp(mel, 1e-9))

    def adapt(self, vocoder, vae, latent, sample_rate=44100, n_mels=None, n_fft=2048, hop_length=512):
        """Derive a mel spectrogram from `latent` (directly or via `vae`) and
        run it through `vocoder`.

        Returns:
            1-tuple with an AUDIO dict {'waveform', 'sample_rate'}.

        Raises:
            RuntimeError: when no mel can be derived or the vocoder exposes
            no known API.
        """
        # Introspect the vocoder for its expected mel configuration; the UI
        # normally supplies n_mels explicitly, so this mainly serves
        # programmatic callers passing n_mels=None.
        if n_mels is None:
            n_mels = getattr(vocoder, 'n_mels', None)
        if n_mels is None and hasattr(vocoder, 'config') and getattr(vocoder.config, 'n_mels', None) is not None:
            n_mels = vocoder.config.n_mels
        if n_mels is None:
            n_mels = 128
        n_fft = getattr(vocoder, 'n_fft', n_fft)
        hop = getattr(vocoder, 'hop_length', hop_length)

        # Step 1: decode the latent with the VAE (best effort).
        audio_wave = None
        if isinstance(latent, dict) and 'samples' in latent:
            try:
                audio_wave = vae.decode(latent['samples']).movedim(-1, 1)
            except Exception as e:
                logger.warning(f"VAE decode failed: {e}")
                audio_wave = None

        # Step 1b: detect latents that already look like mel spectrograms.
        latent_is_mel = False
        if isinstance(latent, dict) and 'samples' in latent:
            samples = latent['samples']
            # Heuristic: a small trailing frequency dim is probably a mel.
            if samples.dim() == 4 and samples.shape[-1] <= max(128, n_mels):
                latent_is_mel = True
            elif samples.dim() == 3 and samples.shape[1] == n_mels:
                latent_is_mel = True

        # Step 2: obtain a mel spectrogram, target layout [B, n_mels, T].
        mel = None
        if latent_is_mel:
            samples = latent['samples']
            if samples.dim() == 4:
                # [B, C, T, F] -> mean over C -> [B, T, F] -> [B, F, T]
                mel = samples.mean(dim=1).permute(0, 2, 1)
                if mel.shape[1] != n_mels:
                    mel = torch.nn.functional.interpolate(mel.unsqueeze(1), size=(n_mels, mel.shape[2]), mode='bilinear', align_corners=False).squeeze(1)
            elif samples.dim() == 3:
                # [B, C, T] - assume the channel dim already holds mel bins.
                if samples.shape[1] != n_mels:
                    mel = torch.nn.functional.interpolate(samples.unsqueeze(1), size=(n_mels, samples.shape[2]), mode='bilinear', align_corners=False).squeeze(1)
                else:
                    mel = samples
            else:
                mel = samples
        elif audio_wave is not None:
            # Resample decoded audio to the vocoder's rate when they differ.
            vocoder_sr = getattr(vocoder, 'sampling_rate', getattr(vocoder, 'sample_rate', sample_rate))
            if hasattr(audio_wave, 'shape') and int(vocoder_sr) != int(sample_rate):
                try:
                    import torchaudio
                    resampler = torchaudio.transforms.Resample(orig_freq=int(sample_rate), new_freq=int(vocoder_sr))
                    audio_wave = resampler(audio_wave)
                    sample_rate = int(vocoder_sr)
                except Exception:
                    # Best effort: continue at the original rate.
                    logger.warning('Resample failed: torchaudio not available or error during resample')
            mel = self._to_mel_torch(audio_wave, sr=sample_rate, n_fft=n_fft, hop=hop, n_mels=n_mels)
        else:
            # Last resort: collapse the raw latent values and hope for the best.
            try:
                s = latent['samples']
                mel = s.mean(dim=1) if s.dim() >= 4 else s
            except Exception as e:
                logger.error(f"Failed to derive mel: {e}")
                raise RuntimeError("Unable to derive mel from latent")

        if mel is None:
            raise RuntimeError('Failed to produce mel for vocoder input')

        # Most vocoders accept (batch, n_mels, T); batch an unbatched mel.
        batched = mel.unsqueeze(0) if mel.dim() == 2 else mel

        # Heuristic: an all-non-negative spectrogram has probably not been
        # log-scaled yet; vocoders typically expect log mel.
        try:
            if batched.min() >= 0:
                batched = torch.log(torch.clamp(batched, min=1e-9))
        except Exception:
            pass

        # Step 3: try the common vocoder entry points in order.
        for api in ('infer', 'synthesize', 'decode'):
            fn = getattr(vocoder, api, None)
            if fn is not None:
                try:
                    out = fn(batched)
                    return ({'waveform': out, 'sample_rate': sample_rate},)
                except Exception as e:
                    logger.warning(f"vocoder.{api}() failed: {e}")

        # If the vocoder is a plain function/functor, call it directly.
        if callable(vocoder):
            try:
                out = vocoder(batched)
                return ({'waveform': out, 'sample_rate': sample_rate},)
            except Exception as e:
                logger.warning(f"vocoder callable failed: {e}")

        logger.error('No known vocoder API found. Please use a vocoder object exposing infer/synthesize/decode or pass a callable.')
        raise RuntimeError('Unsupported vocoder object')
183+
184+
185+
# ComfyUI registration tables for this module.
NODE_CLASS_MAPPINGS = {"AceStepVocoderAdapter": AceStepVocoderAdapter}

NODE_DISPLAY_NAMES = {"AceStepVocoderAdapter": "Ace-Step Vocoder Adapter"}

0 commit comments

Comments
 (0)