Changes from all commits (30 commits)
254622d  init (yousef-rafat, Sep 5, 2025)
1cff9b8  Merge branch 'master' into yousef-higgsv2 (yousef-rafat, Sep 5, 2025)
df4b6a2  removed test files (yousef-rafat, Sep 5, 2025)
6e9335d  Merge branch 'yousef-higgsv2' of https://github.com/yousef-rafat/Comf… (yousef-rafat, Sep 5, 2025)
57c15f9  styling fixes (yousef-rafat, Sep 5, 2025)
f8d4891  additional styling (yousef-rafat, Sep 5, 2025)
233e441  . (yousef-rafat, Sep 5, 2025)
6412422  bug fixes + added some features (yousef-rafat, Sep 8, 2025)
5191fb2  Merge branch 'master' into yousef-higgsv2 (yousef-rafat, Sep 9, 2025)
2ac8999  final (yousef-rafat, Sep 9, 2025)
fee1e57  Merge branch 'yousef-higgsv2' of https://github.com/yousef-rafat/Comf… (yousef-rafat, Sep 9, 2025)
86df359  Merge branch 'master' into yousef-higgsv2 (yousef-rafat, Sep 10, 2025)
cf18a06  Merge branch 'master' into yousef-higgsv2 (yousef-rafat, Sep 12, 2025)
b17463d  Update supported_models.py (yousef-rafat, Sep 12, 2025)
b583c39  Merge branch 'master' into yousef-higgsv2 (yousef-rafat, Sep 15, 2025)
acdb10a  Merge branch 'master' into yousef-higgsv2 (yousef-rafat, Sep 17, 2025)
ca7ecae  Merge branch 'master' into yousef-higgsv2 (yousef-rafat, Sep 20, 2025)
07cfed8  Merge branch 'master' into yousef-higgsv2 (yousef-rafat, Oct 8, 2025)
203849f  removed cuda graphs + changes to loudness.py (yousef-rafat, Oct 24, 2025)
b80cbec  removed unused import (yousef-rafat, Oct 24, 2025)
faa8667  forgot cudagraph import in higgsv2 model.py (yousef-rafat, Oct 24, 2025)
01b3252  . (yousef-rafat, Oct 24, 2025)
76fdb4b  cache updates (yousef-rafat, Oct 24, 2025)
907ed37  ruff checks (yousef-rafat, Oct 24, 2025)
90e7674  moved autoregressive nodes into an extra file (yousef-rafat, Oct 29, 2025)
01e93c7  Merge branch 'master' into yousef-higgsv2 (yousef-rafat, Nov 14, 2025)
470bd9e  changed return type to generated_tokens (yousef-rafat, Nov 20, 2025)
21ac52a  Merge branch 'master' into yousef-higgsv2 (yousef-rafat, Dec 3, 2025)
c7851ae  updated name to generated_tokens (yousef-rafat, Dec 3, 2025)
4ece4ee  Merge branch 'yousef-higgsv2' of https://github.com/yousef-rafat/Comf… (yousef-rafat, Dec 3, 2025)
734 changes: 734 additions & 0 deletions comfy/autoregressive_sampling.py

Large diffs are not rendered by default.

311 changes: 311 additions & 0 deletions comfy/ldm/higgsv2/loudness.py
@@ -0,0 +1,311 @@
import copy
import math
import torch
import scipy.signal
import numpy as np
import torch.nn.functional as F
from typing import Optional, List

# Helpers defaulting to the new PyTorch FFT API.
def _new_rfft(x: torch.Tensor):
    z = torch.fft.rfft(x, dim=-1)
    return torch.view_as_real(z)

def _new_irfft(x: torch.Tensor, length: int):
    x = torch.view_as_complex(x)
    return torch.fft.irfft(x, length, dim=-1)

def _compl_mul_conjugate(a: torch.Tensor, b: torch.Tensor):
    # Multiply `a` by the conjugate of `b` using the PyTorch complex API.
    return torch.view_as_real(torch.view_as_complex(a) * torch.view_as_complex(b).conj())

def unfold(input, kernel_size: int, stride: int):
    # Frame the last dimension into overlapping windows of `kernel_size`,
    # hopped by `stride`, as a zero-copy strided view of the padded input.
    shape = list(input.shape)
    length = shape.pop(-1)

    n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1
    tgt_length = (n_frames - 1) * stride + kernel_size

    padded = F.pad(input, (0, tgt_length - length)).contiguous()
    strides: List[int] = []

    for dim in range(padded.dim()):
        strides.append(padded.stride(dim))

    last_stride = strides.pop(-1)
    assert last_stride == 1, 'data should be contiguous'

    strides = strides + [stride, 1]
    return padded.as_strided(shape + [n_frames, kernel_size], strides)
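
# A minimal usage sketch (assumed shapes, not part of the module): framing a
# 10-sample signal into windows of 4 with hop 2 yields 4 overlapping frames.
#
#     x = torch.arange(10.).view(1, 1, 10)
#     frames = unfold(x, kernel_size=4, stride=2)   # shape (1, 1, 4, 4)
#     # frames[0, 0] == [[0,1,2,3], [2,3,4,5], [4,5,6,7], [6,7,8,9]]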

# Convert the signal and filter to the frequency domain, multiply them, then
# inverse-FFT back to the time domain. Faster than a sliding window in the time domain.
def fft_conv1d(
        input: torch.Tensor, weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None, stride: int = 1, padding: int = 0,
        block_ratio: float = 5):

    input = F.pad(input, (padding, padding))
    batch, _, length = input.shape
    out_channels, _, kernel_size = weight.shape

    _rfft = _new_rfft
    _irfft = _new_irfft

    if length < kernel_size:
        raise RuntimeError(f"Input should be at least as large as the kernel size {kernel_size}, "
                           f"but it is only {length} samples long.")
    if block_ratio < 1:
        raise RuntimeError("Block ratio must be at least 1.")

    # Process the input block by block: it is faster and less memory
    # intensive than one large FFT (the likely culprit is `torch.einsum`).
    block_size: int = min(int(kernel_size * block_ratio), length)
    fold_stride = block_size - kernel_size + 1

    # Zero-pad the weight up to the block size (replaces to_pad).
    weight = F.pad(weight, (0, block_size - weight.shape[-1]), mode="constant", value=0)
    weight_z = _rfft(weight)

    # Pad the input and extract the frames on which the block convolutions run.
    frames = unfold(input, block_size, fold_stride)

    frames_z = _rfft(frames)
    out_z = _compl_mul_conjugate(frames_z, weight_z)
    out = _irfft(out_z, block_size)
    # The last bit is invalid, because the FFT performs a circular convolution.
    out = out[..., :-kernel_size + 1]
    out = out.reshape(batch, out_channels, -1)
    out = out[..., ::stride]
    target_length = (length - kernel_size) // stride + 1
    out = out[..., :target_length]
    if bias is not None:
        out += bias[:, None]
    return out
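
# A minimal sanity-check sketch (assumed shapes; single in/out channel, which is
# all this module exercises): fft_conv1d should match torch's direct conv1d.
#
#     x = torch.randn(2, 1, 4096)
#     w = torch.randn(1, 1, 64)
#     out = fft_conv1d(x, w)
#     ref = F.conv1d(x, w)
#     assert torch.allclose(out, ref, atol=1e-4)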

class IIRfilter(object):
    # Biquad IIR filter stage used for the K-weighting pre-filter:
    # G (dB gain), Q, fc (Hz), sample rate (Hz), and the filter type.

    def __init__(self, G, Q, fc, rate, filter_type, passband_gain=1.0):
        G, Q, fc, rate = [t if isinstance(t, torch.Tensor) else torch.as_tensor(t) for t in (G, Q, fc, rate)]
        self.G = G
        self.Q = Q
        self.fc = fc
        self.rate = rate
        self.filter_type = filter_type
        self.passband_gain = passband_gain

    def generate_coefficients(self):

        A = 10**(self.G/40.0)
        w0 = 2.0 * torch.pi * (self.fc / self.rate)
        alpha = torch.sin(w0) / (2.0 * self.Q)

        if self.filter_type == 'high_shelf':
            b0 = A * ( (A+1) + (A-1) * torch.cos(w0) + 2 * torch.sqrt(A) * alpha )
            b1 = -2 * A * ( (A-1) + (A+1) * torch.cos(w0) )
            b2 = A * ( (A+1) + (A-1) * torch.cos(w0) - 2 * torch.sqrt(A) * alpha )
            a0 = (A+1) - (A-1) * torch.cos(w0) + 2 * torch.sqrt(A) * alpha
            a1 = 2 * ( (A-1) - (A+1) * torch.cos(w0) )
            a2 = (A+1) - (A-1) * torch.cos(w0) - 2 * torch.sqrt(A) * alpha

        elif self.filter_type == 'high_pass':
            b0 = (1 + torch.cos(w0))/2
            b1 = -(1 + torch.cos(w0))
            b2 = (1 + torch.cos(w0))/2
            a0 = 1 + alpha
            a1 = -2 * torch.cos(w0)
            a2 = 1 - alpha

        else:
            raise ValueError(f"Unsupported filter type: {self.filter_type}")

        return torch.stack([b0, b1, b2])/a0, torch.stack([a0, a1, a2])/a0

    def apply_filter(self, data):
        b, a = self.b_and_a
        return self.passband_gain * scipy.signal.lfilter(b.numpy(), a.numpy(), data)

    @property
    def b_and_a(self):
        return self.generate_coefficients()
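
# A minimal usage sketch (the G/Q/fc values mirror the high-shelf K-weighting
# stage below; the 48 kHz rate is an assumption):
#
#     hs = IIRfilter(4.0, 1 / np.sqrt(2), 1500.0, 48000, 'high_shelf')
#     b, a = hs.b_and_a                              # normalized biquad coefficients
#     y = hs.apply_filter(np.random.randn(48000))    # filtered signal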

class Meter(torch.nn.Module):

    def __init__(
        self,
        rate: int,
        filter_class: str = "K-weighting",
        block_size: float = 0.400,
        zeros: int = 512,
        use_fir: bool = False,
    ):
        super().__init__()

        self.rate = rate
        self.filter_class = filter_class
        self.block_size = block_size
        self.use_fir = use_fir

        G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41]))
        self.register_buffer("G", G)

        self._filters = {}
        self._filters['high_shelf'] = IIRfilter(4.0, 1/np.sqrt(2), 1500.0, self.rate, 'high_shelf')
        self._filters['high_pass'] = IIRfilter(0.0, 0.5, 38.0, self.rate, 'high_pass')

        # Compute impulse responses so that filtering is fast via
        # a convolution at runtime, on GPU, unlike lfilter.
        impulse = np.zeros((zeros,))
        impulse[..., 0] = 1.0

        firs = np.zeros((len(self._filters), 1, zeros))
        passband_gain = torch.tensor([f.passband_gain for f in self._filters.values()])

        for i, (_, filter_stage) in enumerate(self._filters.items()):
            b, a = filter_stage.b_and_a
            firs[i] = scipy.signal.lfilter(b.numpy(), a.numpy(), impulse)

        firs = torch.from_numpy(firs[..., ::-1].copy()).float()

        self.register_buffer("firs", firs)
        self.register_buffer("passband_gain", passband_gain)

    def apply_filter_fir(self, data: torch.Tensor):

        # Data is of shape (nb, nt, nch).
        # Reshape to (nb*nch, 1, nt).
        nb, nt, nch = data.shape
        data = data.permute(0, 2, 1)
        data = data.reshape(nb * nch, 1, nt)

        # Apply padding
        pad_length = self.firs.shape[-1]

        # Apply filtering in sequence
        for i in range(self.firs.shape[0]):
            data = F.pad(data, (pad_length, pad_length))
            data = fft_conv1d(data, self.firs[i, None, ...])
            data = self.passband_gain[i] * data
            data = data[..., 1 : nt + 1]

        data = data.permute(0, 2, 1)
        data = data[:, :nt, :]
        return data

    def apply_filter(self, data: torch.Tensor):
        return self.apply_filter_fir(data)

    def forward(self, data: torch.Tensor):
        return self.integrated_loudness(data)

    def _unfold(self, input_data):
        T_g = self.block_size
        overlap = 0.75  # overlap of 75% of the block duration
        step = 1.0 - overlap  # hop size as a fraction of the block

        kernel_size = int(T_g * self.rate)
        stride = int(T_g * self.rate * step)
        unfolded = unfold(input_data.permute(0, 2, 1), kernel_size, stride)
        unfolded = unfolded.transpose(-1, -2)

        return unfolded

    def integrated_loudness(self, data: torch.Tensor):

        if not torch.is_tensor(data):
            data = torch.from_numpy(data).float()
        else:
            data = data.float()

        input_data = copy.copy(data)
        # Data always has a batch and channel dimension.
        # Is of shape (nb, nt, nch)
        if input_data.ndim < 2:
            input_data = input_data.unsqueeze(-1)
        if input_data.ndim < 3:
            input_data = input_data.unsqueeze(0)

        nb, _, nch = input_data.shape

        # Apply frequency weighting filters - account
        # for the acoustic response of the head and auditory system
        input_data = self.apply_filter(input_data)

        G = self.G  # channel gains
        T_g = self.block_size  # 400 ms gating block standard
        Gamma_a = -70.0  # -70 LKFS = absolute loudness threshold

        unfolded = self._unfold(input_data)

        z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
        l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True))
        l = l.expand_as(z)

        # find gating block indices above absolute threshold
        z_avg_gated = z
        z_avg_gated[l <= Gamma_a] = 0
        masked = l > Gamma_a
        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)

        # calculate the relative threshold value (see eq. 6)
        Gamma_r = (
            -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
        )
        Gamma_r = Gamma_r[:, None, None]
        Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1])

        # find gating block indices above relative and absolute thresholds (end of eq. 7)
        z_avg_gated = z
        z_avg_gated[l <= Gamma_a] = 0
        z_avg_gated[l <= Gamma_r] = 0
        masked = (l > Gamma_a) * (l > Gamma_r)
        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)

        # Cannot use nan_to_num (pytorch 1.8 does not come with a GCP-supported CUDA version)
        # z_avg_gated = torch.nan_to_num(z_avg_gated)
        z_avg_gated = torch.where(
            z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated
        )
        z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max)
        z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min)

        LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1))
        return LUFS.float()
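
# A minimal usage sketch (assumed values, not part of the module): measuring a
# full-scale 997 Hz sine with the K-weighted BS.1770 meter.
#
#     sr = 48000
#     t = torch.arange(sr) / sr
#     x = torch.sin(2 * torch.pi * 997 * t)[None, :, None]  # (nb, nt, nch)
#     meter = Meter(sr)
#     print(meter.integrated_loudness(x))  # roughly -3 LUFS for a 0 dBFS sine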


def loudness(
    audio_data, sample_rate: int, target_loudness: float, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
):
    MIN_LOUDNESS = -70
    device = audio_data.device

    original_length = audio_data.shape[-1]
    signal_duration = original_length / sample_rate

    # Pad if too short
    if signal_duration < 0.5:
        pad_len = int((0.5 - signal_duration) * sample_rate)
        audio_data = torch.nn.functional.pad(audio_data, (0, pad_len), mode="constant", value=0)

    # create BS.1770 meter
    meter = Meter(
        sample_rate, filter_class=filter_class, block_size=block_size, **kwargs
    )
    meter = meter.to(audio_data.device)
    # measure loudness
    loudness = meter.integrated_loudness(audio_data.permute(0, 2, 1))
    audio_data = audio_data[..., :original_length]
    min_loudness = (
        torch.ones_like(loudness, device=loudness.device) * MIN_LOUDNESS
    )
    _loudness = torch.maximum(loudness, min_loudness)

    _loudness = _loudness.to(device)

    delta_loudness = target_loudness - _loudness
    gain = torch.pow(torch.tensor(10.0, device=device, dtype=audio_data.dtype), delta_loudness / 20.0)

    # Broadcast the per-item gain over (nch, nt).
    output = gain.view(-1, 1, 1) * audio_data

    if torch.max(torch.abs(output)) >= 1.0:
        import warnings
        warnings.warn("Possible clipped samples in output.")

    return output
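
# A minimal usage sketch (assumed shapes: audio as (batch, channels, samples)):
# normalizing a quiet clip to the broadcast target of -23 LUFS.
#
#     audio = torch.randn(1, 1, 48000) * 0.1
#     normalized = loudness(audio, sample_rate=48000, target_loudness=-23.0)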