Commit 254622d
Commit message: init
1 parent: 27870ec
19 files changed, +1260506 -9 lines

comfy/autoregressive_sampling.py
Lines changed: 632 additions & 0 deletions
Large diffs are not rendered by default.

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import torch
import torch.nn as nn
from typing import Optional, Dict
import gc

_NUM_WARMUP_ITERS = 2


class CUDAGraphRunner(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        assert self._graph is not None
        return self._graph

    def capture(self, *args, **kwargs):
        assert self._graph is None

        # Warm up outside of capture so one-time initialization (cuBLAS handles,
        # autotuning, caching allocations) does not end up inside the graph.
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(*args, **kwargs)

        torch.cuda.synchronize()

        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=kwargs.get("memory_pool", None), stream=kwargs.get("stream", None)):
            last_hidden_states = self.model(*args, **kwargs)
            gc.collect()

        torch.cuda.synchronize()

        # Remember the captured tensors; replays must write new data into these exact buffers.
        self.input_buffers = {
            "args": [arg for arg in args if isinstance(arg, torch.Tensor)],
            "kwargs": {k: v for k, v in kwargs.items() if isinstance(v, torch.Tensor)},
        }

        self.output_buffers = {
            "hidden_states": last_hidden_states
        }

    def forward(self, *args, **kwargs):
        # Copy the new inputs into the captured buffers, then replay the graph.
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                self.input_buffers["args"][i].copy_(arg, non_blocking=True)

        for k, v in kwargs.items():
            if k in self.input_buffers["kwargs"] and isinstance(v, torch.Tensor):
                self.input_buffers["kwargs"][k].copy_(v, non_blocking=True)

        self.graph.replay()

        return self.output_buffers["hidden_states"]
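
A minimal usage sketch, not part of this commit; it assumes a CUDA device and a toy module, with static shapes between capture and replay:

dummy = nn.Linear(16, 16).cuda()
x = torch.randn(4, 16, device="cuda")

runner = CUDAGraphRunner(dummy)
runner.capture(x)                                   # warm up, then record one CUDA graph
y = runner(torch.randn(4, 16, device="cuda"))       # copy into the captured buffer and replay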

comfy/ldm/higgsv2/loudness.py
Lines changed: 330 additions & 0 deletions
@@ -0,0 +1,330 @@
import copy
import math
import torch
import scipy.signal
import torchaudio
import numpy as np
import torch.nn.functional as F
from typing import Optional, List

# Defaulted to the new torch.fft API.
def _new_rfft(x: torch.Tensor):
    z = torch.fft.rfft(x, dim=-1)
    return torch.view_as_real(z)

def _new_irfft(x: torch.Tensor, length: int):
    x = torch.view_as_complex(x)
    return torch.fft.irfft(x, length, dim=-1)

def _compl_mul_conjugate(a: torch.Tensor, b: torch.Tensor):
    # Changed to use the torch complex API: multiply a by the conjugate of b.
    return torch.view_as_real(torch.view_as_complex(a) * torch.view_as_complex(b).conj())
def unfold(input, kernel_size: int, stride: int):
    # Frame the last dimension into overlapping windows of `kernel_size` with hop
    # `stride`, zero-padding at the end, returned as a strided view.
    shape = list(input.shape)
    length = shape.pop(-1)

    n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1
    tgt_length = (n_frames - 1) * stride + kernel_size

    padded = F.pad(input, (0, tgt_length - length)).contiguous()
    strides: List[int] = []

    for dim in range(padded.dim()):
        strides.append(padded.stride(dim))

    last_stride = strides.pop(-1)
    assert last_stride == 1, 'data should be contiguous'

    strides = strides + [stride, 1]
    return padded.as_strided(shape + [n_frames, kernel_size], strides)
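# Illustration, not part of this commit: `unfold` turns a (1, 1, 10) input with
# kernel_size=4 and stride=2 into a (1, 1, 4, 4) view of overlapping frames
# starting at samples 0, 2, 4 and 6.
_example = unfold(torch.arange(10.0).view(1, 1, 10), kernel_size=4, stride=2)
assert _example.shape == (1, 1, 4, 4)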
# Convert the signal and filter to the frequency domain, multiply them, then
# inverse-FFT back to the time domain. Faster than a sliding window over the
# time domain for long kernels.
def fft_conv1d(
        input: torch.Tensor, weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None, stride: int = 1, padding: int = 0,
        block_ratio: float = 5):

    input = F.pad(input, (padding, padding))
    batch, _, length = input.shape
    out_channels, _, kernel_size = weight.shape

    _rfft = _new_rfft
    _irfft = _new_irfft

    if length < kernel_size:
        raise RuntimeError(f"Input should be at least as large as the kernel size {kernel_size}, "
                           f"but it is only {length} samples long.")
    if block_ratio < 1:
        raise RuntimeError("Block ratio must be at least 1.")

    # Process the input block by block, as for some reason this is faster and
    # less memory intensive (the culprit is probably `torch.einsum`).
    block_size: int = min(int(kernel_size * block_ratio), length)
    fold_stride = block_size - kernel_size + 1

    # Zero-pad the kernel to the block size (replaces to_pad).
    weight = F.pad(weight, (0, block_size - weight.shape[-1]), mode="constant", value=0)
    weight_z = _rfft(weight)

    # Pad the input and extract the overlapping frames each block operates on.
    frames = unfold(input, block_size, fold_stride)

    frames_z = _rfft(frames)
    out_z = _compl_mul_conjugate(frames_z, weight_z)
    out = _irfft(out_z, block_size)
    # The last bit is invalid, because the FFT performs a circular convolution.
    out = out[..., :-kernel_size + 1]
    out = out.reshape(batch, out_channels, -1)
    out = out[..., ::stride]
    target_length = (length - kernel_size) // stride + 1
    out = out[..., :target_length]
    if bias is not None:
        out += bias[:, None]
    return out
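# Sanity-check sketch, not part of this commit: for a single-channel kernel,
# fft_conv1d is expected to match torch's direct convolution up to float32
# round-off; the shapes below are assumed values.
_x = torch.randn(2, 1, 4800)
_w = torch.randn(1, 1, 512)
print((F.conv1d(_x, _w) - fft_conv1d(_x, _w)).abs().max())  # should be tiny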
class IIRfilter(object):
    # Second-order IIR (biquad) stage used for the K-weighting pre-filter.

    def __init__(self, G, Q, fc, rate, filter_type, passband_gain=1.0):
        self.G = G
        self.Q = Q
        self.fc = fc
        self.rate = rate
        self.filter_type = filter_type
        self.passband_gain = passband_gain

    def generate_coefficients(self):
        A = 10 ** (self.G / 40.0)
        w0 = 2.0 * np.pi * (self.fc / self.rate)
        alpha = np.sin(w0) / (2.0 * self.Q)

        if self.filter_type == 'high_shelf':
            b0 = A * ((A + 1) + (A - 1) * np.cos(w0) + 2 * np.sqrt(A) * alpha)
            b1 = -2 * A * ((A - 1) + (A + 1) * np.cos(w0))
            b2 = A * ((A + 1) + (A - 1) * np.cos(w0) - 2 * np.sqrt(A) * alpha)
            a0 = (A + 1) - (A - 1) * np.cos(w0) + 2 * np.sqrt(A) * alpha
            a1 = 2 * ((A - 1) - (A + 1) * np.cos(w0))
            a2 = (A + 1) - (A - 1) * np.cos(w0) - 2 * np.sqrt(A) * alpha

        elif self.filter_type == 'high_pass':
            b0 = (1 + np.cos(w0)) / 2
            b1 = -(1 + np.cos(w0))
            b2 = (1 + np.cos(w0)) / 2
            a0 = 1 + alpha
            a1 = -2 * np.cos(w0)
            a2 = 1 - alpha

        return np.array([b0, b1, b2]) / a0, np.array([a0, a1, a2]) / a0

    def apply_filter(self, data):
        b, a = self.b_and_a
        return self.passband_gain * scipy.signal.lfilter(b, a, data)

    @property
    def b_and_a(self):
        return self.generate_coefficients()
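# Illustration, not part of this commit: the two K-weighting stages that Meter
# builds below, instantiated directly at an assumed 48 kHz sample rate.
_shelf = IIRfilter(4.0, 1 / np.sqrt(2), 1500.0, 48000, 'high_shelf')
_hpf = IIRfilter(0.0, 0.5, 38.0, 48000, 'high_pass')
_b, _a = _shelf.b_and_a   # normalized biquad coefficients, with _a[0] == 1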
class Meter(torch.nn.Module):
    # Tensorized ITU-R BS.1770 loudness meter.

    def __init__(
        self,
        rate: int,
        filter_class: str = "K-weighting",
        block_size: float = 0.400,
        zeros: int = 512,
        use_fir: bool = False,
    ):
        super().__init__()

        self.rate = rate
        self.filter_class = filter_class
        self.block_size = block_size
        self.use_fir = use_fir

        # Channel gains for up to five channels (L, R, C, Ls, Rs).
        G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41]))
        self.register_buffer("G", G)

        self._filters = {}
        self._filters['high_shelf'] = IIRfilter(4.0, 1 / np.sqrt(2), 1500.0, self.rate, 'high_shelf')
        self._filters['high_pass'] = IIRfilter(0.0, 0.5, 38.0, self.rate, 'high_pass')

        # Compute truncated impulse responses so that filtering is fast via
        # a convolution at runtime, on GPU, unlike lfilter.
        impulse = np.zeros((zeros,))
        impulse[..., 0] = 1.0

        firs = np.zeros((len(self._filters), 1, zeros))
        passband_gain = torch.tensor([filter.passband_gain for filter in self._filters.values()])

        for i, (_, filter_stage) in enumerate(self._filters.items()):
            b, a = filter_stage.b_and_a
            firs[i] = scipy.signal.lfilter(b, a, impulse)

        firs = torch.from_numpy(firs[..., ::-1].copy()).float()

        self.register_buffer("firs", firs)
        self.register_buffer("passband_gain", passband_gain)

    def apply_filter_gpu(self, data: torch.Tensor):
        # Data is of shape (nb, nt, nch).
        # Reshape to (nb * nch, 1, nt).
        nb, nt, nch = data.shape
        data = data.permute(0, 2, 1)
        data = data.reshape(nb * nch, 1, nt)

        # Padding applied before each FIR stage.
        pad_length = self.firs.shape[-1]

        # Apply the filtering stages in sequence.
        for i in range(self.firs.shape[0]):
            data = F.pad(data, (pad_length, pad_length))
            data = fft_conv1d(data, self.firs[i, None, ...])
            data = self.passband_gain[i] * data
            data = data[..., 1:nt + 1]

        data = data.permute(0, 2, 1)
        data = data[:, :nt, :]
        return data

    def apply_filter_cpu(self, data: torch.Tensor):
        for _, filter_stage in self._filters.items():
            passband_gain = filter_stage.passband_gain
            b, a = filter_stage.b_and_a

            a_coeffs = torch.from_numpy(a).float().to(data.device)
            b_coeffs = torch.from_numpy(b).float().to(data.device)

            _data = data.permute(0, 2, 1)
            filtered = torchaudio.functional.lfilter(
                _data, a_coeffs, b_coeffs, clamp=False
            )
            data = passband_gain * filtered.permute(0, 2, 1)
        return data

    def apply_filter(self, data: torch.Tensor):
        if data.is_cuda or self.use_fir:
            data = self.apply_filter_gpu(data)
        else:
            data = self.apply_filter_cpu(data)
        return data

    def forward(self, data: torch.Tensor):
        return self.integrated_loudness(data)

    def _unfold(self, input_data):
        T_g = self.block_size
        overlap = 0.75  # overlap of 75% of the block duration
        step = 1.0 - overlap  # step size as a fraction of the block

        kernel_size = int(T_g * self.rate)
        stride = int(T_g * self.rate * step)
        unfolded = unfold(input_data.permute(0, 2, 1), kernel_size, stride)
        unfolded = unfolded.transpose(-1, -2)

        return unfolded
    def integrated_loudness(self, data: torch.Tensor):
        if not torch.is_tensor(data):
            data = torch.from_numpy(data).float()
        else:
            data = data.float()

        input_data = copy.copy(data)
        # Data always gets a batch and a channel dimension,
        # i.e. shape (nb, nt, nch).
        if input_data.ndim < 2:
            input_data = input_data.unsqueeze(-1)
        if input_data.ndim < 3:
            input_data = input_data.unsqueeze(0)

        nb, _, nch = input_data.shape

        # Apply frequency weighting filters, which account for the
        # acoustic response of the head and auditory system.
        input_data = self.apply_filter(input_data)

        G = self.G  # channel gains
        T_g = self.block_size  # 400 ms gating block standard
        Gamma_a = -70.0  # -70 LKFS = absolute loudness threshold

        unfolded = self._unfold(input_data)

        # Mean-square energy per gating block and channel, then block loudness.
        z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
        l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True))
        l = l.expand_as(z)

        # Find gating block indices above the absolute threshold.
        z_avg_gated = z
        z_avg_gated[l <= Gamma_a] = 0
        masked = l > Gamma_a
        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)

        # Calculate the relative threshold value (see eq. 6).
        Gamma_r = (
            -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
        )
        Gamma_r = Gamma_r[:, None, None]
        Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1])

        # Find gating block indices above both the relative and absolute thresholds (end of eq. 7).
        z_avg_gated = z
        z_avg_gated[l <= Gamma_a] = 0
        z_avg_gated[l <= Gamma_r] = 0
        masked = (l > Gamma_a) * (l > Gamma_r)
        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)

        # Cannot use nan_to_num (pytorch 1.8 does not come with a GCP-supported CUDA version):
        # z_avg_gated = torch.nan_to_num(z_avg_gated)
        z_avg_gated = torch.where(
            z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated
        )
        z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max)
        z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min)

        LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1))
        return LUFS.float()
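# Illustration, not part of this commit: measuring the integrated loudness of a
# mono test tone; the 1 kHz frequency and 48 kHz rate are assumed values.
_sr = 48000
_t = torch.arange(2 * _sr) / _sr
_tone = 0.1 * torch.sin(2 * np.pi * 1000.0 * _t)
_lufs = Meter(_sr).integrated_loudness(_tone[None, :, None])  # expects (nb, nt, nch)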
def loudness(
    audio_data, sample_rate: int, target_loudness: int, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
):
    MIN_LOUDNESS = -70
    device = audio_data.device

    original_length = audio_data.shape[-1]
    signal_duration = original_length / sample_rate

    # Pad if the signal is shorter than 0.5 s.
    if signal_duration < 0.5:
        pad_len = int((0.5 - signal_duration) * sample_rate)
        audio_data = torch.nn.functional.pad(audio_data, (0, pad_len), mode="constant", value=0)

    # Create a BS.1770 meter.
    meter = Meter(
        sample_rate, filter_class=filter_class, block_size=block_size, **kwargs
    )
    meter = meter.to(audio_data.device)
    # Measure loudness.
    loudness = meter.integrated_loudness(audio_data.permute(0, 2, 1))
    audio_data = audio_data[..., :original_length]
    min_loudness = (
        torch.ones_like(loudness, device=loudness.device) * MIN_LOUDNESS
    )
    _loudness = torch.maximum(loudness, min_loudness)

    _loudness = _loudness.to(device)

    # Gain (in dB, converted to linear) needed to reach the target loudness.
    delta_loudness = target_loudness - _loudness
    gain = torch.pow(torch.tensor(10.0, device=device, dtype=audio_data.dtype), delta_loudness / 20.0)

    output = gain * audio_data

    if torch.max(torch.abs(output)) >= 1.0:
        import warnings
        warnings.warn("Possible clipped samples in output.")

    return output
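
A sketch of how this entry point might be called, not part of the commit; the random test signal and the -23 LUFS target are assumed values:

audio = torch.randn(1, 1, 48000) * 0.05             # (batch, channels, samples) at 48 kHz
normalized = loudness(audio, sample_rate=48000, target_loudness=-23)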
