Skip to content

Commit 9e9192d

Browse files
committed
Implement the DCT function using numpy
1 parent 2fb5afa commit 9e9192d

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

python/rapid_paraformer/kaldifeat/feature.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
from scipy.fftpack import dct
32

43

54
# ---------- feature-window ----------
@@ -137,6 +136,28 @@ def extract_window(waveform, blackman_coeff, dither, window_size, window_shift,
137136

138137
# ---------- feature-functions ----------
139138

139+
def dct_np(x: np.array, norm=None):
140+
x_shape = x.shape
141+
N = x_shape[-1]
142+
143+
v = np.hstack([x[:, ::2], x[:, 1::2][:, ::-1]])
144+
Vc = np.fft.fft(v)
145+
Vc = np.dstack([np.real(Vc), np.imag(Vc)])
146+
147+
k = - np.arange(N, dtype=x.dtype)[None, :] * np.pi / (2*N)
148+
W_r = np.cos(k)
149+
W_i = np.sin(k)
150+
151+
V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
152+
153+
if norm == "ortho":
154+
V[:, 0] /= np.sqrt(N) * 2
155+
V[:, 1:] /= np.sqrt(N/2) * 2
156+
157+
V = 2 * V
158+
return V
159+
160+
140161
def compute_spectrum(frames, n):
141162
complex_spec = np.fft.rfft(frames, n)
142163
return np.absolute(complex_spec)
@@ -424,7 +445,7 @@ def compute_mfcc_feats(
424445
window_type=window_type,
425446
dtype=dtype
426447
)
427-
feat = dct(feat, type=2, axis=1, norm='ortho')[:, :num_ceps]
448+
feat = dct_np(feat, norm="ortho")[:, :num_ceps]
428449
lifter_coeffs = compute_lifter_coeffs(cepstral_lifter, num_ceps).astype(dtype)
429450
feat = feat * lifter_coeffs
430451
if use_energy:

0 commit comments

Comments
 (0)