
Commit fd79d32

Add new audio nodes (#9908)
* Add new audio nodes
  - TrimAudioDuration
  - SplitAudioChannels
  - AudioConcat
  - AudioMerge
  - AudioAdjustVolume
* Update nodes_audio.py
* Add EmptyAudio node
* Change duration to Float (allows sub-second values)
1 parent 341b4ad commit fd79d32

File tree: 1 file changed (+223, −0)

comfy_extras/nodes_audio.py

Lines changed: 223 additions & 0 deletions
@@ -11,6 +11,7 @@
 import random
 import hashlib
 import node_helpers
+import logging
 from comfy.cli_args import args
 from comfy.comfy_types import FileLocator

@@ -364,6 +365,216 @@ def load(self, audio):
         return (audio, )


+class TrimAudioDuration:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "audio": ("AUDIO",),
+                "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
+                "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
+            },
+        }
+
+    FUNCTION = "trim"
+    RETURN_TYPES = ("AUDIO",)
+    CATEGORY = "audio"
+    DESCRIPTION = "Trim audio tensor into chosen time range."
+
+    def trim(self, audio, start_index, duration):
+        waveform = audio["waveform"]
+        sample_rate = audio["sample_rate"]
+        audio_length = waveform.shape[-1]
+
+        if start_index < 0:
+            start_frame = audio_length + int(round(start_index * sample_rate))
+        else:
+            start_frame = int(round(start_index * sample_rate))
+        start_frame = max(0, min(start_frame, audio_length - 1))
+
+        end_frame = start_frame + int(round(duration * sample_rate))
+        end_frame = max(0, min(end_frame, audio_length))
+
+        if start_frame >= end_frame:
+            raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
+
+        return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
+
+
+class SplitAudioChannels:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "audio": ("AUDIO",),
+        }}
+
+    RETURN_TYPES = ("AUDIO", "AUDIO")
+    RETURN_NAMES = ("left", "right")
+    FUNCTION = "separate"
+    CATEGORY = "audio"
+    DESCRIPTION = "Separates the audio into left and right channels."
+
+    def separate(self, audio):
+        waveform = audio["waveform"]
+        sample_rate = audio["sample_rate"]
+
+        if waveform.shape[1] != 2:
+            raise ValueError("AudioSplit: Input audio has only one channel.")
+
+        left_channel = waveform[..., 0:1, :]
+        right_channel = waveform[..., 1:2, :]
+
+        return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
+
+
+def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
+    if sample_rate_1 != sample_rate_2:
+        if sample_rate_1 > sample_rate_2:
+            waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1)
+            output_sample_rate = sample_rate_1
+            logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.")
+        else:
+            waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2)
+            output_sample_rate = sample_rate_2
+            logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.")
+    else:
+        output_sample_rate = sample_rate_1
+    return waveform_1, waveform_2, output_sample_rate
+
+
+class AudioConcat:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "audio1": ("AUDIO",),
+            "audio2": ("AUDIO",),
+            "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
+        }}
+
+    RETURN_TYPES = ("AUDIO",)
+    FUNCTION = "concat"
+    CATEGORY = "audio"
+    DESCRIPTION = "Concatenates audio2 to audio1 in the specified direction."
+
+    def concat(self, audio1, audio2, direction):
+        waveform_1 = audio1["waveform"]
+        waveform_2 = audio2["waveform"]
+        sample_rate_1 = audio1["sample_rate"]
+        sample_rate_2 = audio2["sample_rate"]
+
+        if waveform_1.shape[1] == 1:
+            waveform_1 = waveform_1.repeat(1, 2, 1)
+            logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.")
+        if waveform_2.shape[1] == 1:
+            waveform_2 = waveform_2.repeat(1, 2, 1)
+            logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.")
+
+        waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
+
+        if direction == 'after':
+            concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2)
+        elif direction == 'before':
+            concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
+
+        return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
+
+
+class AudioMerge:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "audio1": ("AUDIO",),
+                "audio2": ("AUDIO",),
+                "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
+            },
+        }
+
+    FUNCTION = "merge"
+    RETURN_TYPES = ("AUDIO",)
+    CATEGORY = "audio"
+    DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
+
+    def merge(self, audio1, audio2, merge_method):
+        waveform_1 = audio1["waveform"]
+        waveform_2 = audio2["waveform"]
+        sample_rate_1 = audio1["sample_rate"]
+        sample_rate_2 = audio2["sample_rate"]
+
+        waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
+
+        length_1 = waveform_1.shape[-1]
+        length_2 = waveform_2.shape[-1]
+
+        if length_2 > length_1:
+            logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
+            waveform_2 = waveform_2[..., :length_1]
+        elif length_2 < length_1:
+            logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.")
+            pad_shape = list(waveform_2.shape)
+            pad_shape[-1] = length_1 - length_2
+            pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device)
+            waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1)
+
+        if merge_method == "add":
+            waveform = waveform_1 + waveform_2
+        elif merge_method == "subtract":
+            waveform = waveform_1 - waveform_2
+        elif merge_method == "multiply":
+            waveform = waveform_1 * waveform_2
+        elif merge_method == "mean":
+            waveform = (waveform_1 + waveform_2) / 2
+
+        max_val = waveform.abs().max()
+        if max_val > 1.0:
+            waveform = waveform / max_val
+
+        return ({"waveform": waveform, "sample_rate": output_sample_rate},)
+
+
+class AudioAdjustVolume:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "audio": ("AUDIO",),
+            "volume": ("INT", {"default": 1, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
+        }}
+
+    RETURN_TYPES = ("AUDIO",)
+    FUNCTION = "adjust_volume"
+    CATEGORY = "audio"
+
+    def adjust_volume(self, audio, volume):
+        if volume == 0:
+            return (audio,)
+        waveform = audio["waveform"]
+        sample_rate = audio["sample_rate"]
+
+        gain = 10 ** (volume / 20)
+        waveform = waveform * gain
+
+        return ({"waveform": waveform, "sample_rate": sample_rate},)
+
+
+class EmptyAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
+            "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
+            "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
+        }}
+
+    RETURN_TYPES = ("AUDIO",)
+    FUNCTION = "create_empty_audio"
+    CATEGORY = "audio"
+
+    def create_empty_audio(self, duration, sample_rate, channels):
+        num_samples = int(round(duration * sample_rate))
+        waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
+        return ({"waveform": waveform, "sample_rate": sample_rate},)
+
+
 NODE_CLASS_MAPPINGS = {
     "EmptyLatentAudio": EmptyLatentAudio,
     "VAEEncodeAudio": VAEEncodeAudio,

@@ -375,6 +586,12 @@ def load(self, audio):
     "PreviewAudio": PreviewAudio,
     "ConditioningStableAudio": ConditioningStableAudio,
     "RecordAudio": RecordAudio,
+    "TrimAudioDuration": TrimAudioDuration,
+    "SplitAudioChannels": SplitAudioChannels,
+    "AudioConcat": AudioConcat,
+    "AudioMerge": AudioMerge,
+    "AudioAdjustVolume": AudioAdjustVolume,
+    "EmptyAudio": EmptyAudio,
 }

 NODE_DISPLAY_NAME_MAPPINGS = {

@@ -387,4 +604,10 @@ def load(self, audio):
     "SaveAudioMP3": "Save Audio (MP3)",
     "SaveAudioOpus": "Save Audio (Opus)",
     "RecordAudio": "Record Audio",
+    "TrimAudioDuration": "Trim Audio Duration",
+    "SplitAudioChannels": "Split Audio Channels",
+    "AudioConcat": "Audio Concat",
+    "AudioMerge": "Audio Merge",
+    "AudioAdjustVolume": "Audio Adjust Volume",
+    "EmptyAudio": "Empty Audio",
 }
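
Quick sanity check (not part of the commit): a minimal sketch of driving the new nodes directly from Python. It assumes you run it inside a ComfyUI checkout so that comfy_extras.nodes_audio and its dependencies (torch, torchaudio) import cleanly, and it uses the AUDIO dict convention visible in the diff: {"waveform": tensor of shape (batch, channels, samples), "sample_rate": int}.

import torch
from comfy_extras.nodes_audio import (
    AudioAdjustVolume, AudioConcat, AudioMerge, EmptyAudio,
    SplitAudioChannels, TrimAudioDuration,
)

# 5 seconds of stereo noise in the AUDIO dict convention used by these nodes.
audio = {"waveform": torch.randn(1, 2, 44100 * 5).clamp(-1.0, 1.0), "sample_rate": 44100}

# Keep 1.5 s starting 2 s before the end (negative start_index counts from the end).
trimmed = TrimAudioDuration().trim(audio, start_index=-2.0, duration=1.5)[0]
assert trimmed["waveform"].shape[-1] == int(1.5 * 44100)

# Split stereo into two mono AUDIO outputs.
left, right = SplitAudioChannels().separate(audio)
assert left["waveform"].shape[1] == 1 and right["waveform"].shape[1] == 1

# Append the trimmed clip after the original (mono inputs would be upmixed to stereo first).
joined = AudioConcat().concat(audio, trimmed, direction="after")[0]
assert joined["waveform"].shape[-1] == audio["waveform"].shape[-1] + trimmed["waveform"].shape[-1]

# Overlay a silent 22.05 kHz track: AudioMerge resamples it to 44.1 kHz (via torchaudio),
# zero-pads it to audio1's length, then combines the waveforms sample-wise.
silence = {"waveform": torch.zeros(1, 2, 22050), "sample_rate": 22050}
mix = AudioMerge().merge(audio, silence, merge_method="add")[0]
assert mix["sample_rate"] == 44100

# +6 dB of gain, i.e. gain = 10 ** (6 / 20) ≈ 1.995, roughly doubling the amplitude.
louder = AudioAdjustVolume().adjust_volume(audio, volume=6)[0]
assert torch.allclose(louder["waveform"], audio["waveform"] * 10 ** (6 / 20))

# A 2 s silent stereo clip from EmptyAudio.
blank = EmptyAudio().create_empty_audio(duration=2.0, sample_rate=44100, channels=2)[0]
assert blank["waveform"].shape == (1, 2, 88200)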
