Skip to content

Commit b8e098c

Browse files
committed
changed tests to not have matplotlib
1 parent dd60041 commit b8e098c

File tree

6 files changed

+534
-241
lines changed

6 files changed

+534
-241
lines changed

doc/source/notebooks/chirpexample.ipynb

Lines changed: 4 additions & 15 deletions
Large diffs are not rendered by default.

examples/chirpexamples.py

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
"""
2+
Create a chirp and test the PFB.
3+
4+
"""
5+
6+
from pathlib import Path
7+
import numpy as np
8+
import scipy.signal as sig
9+
from mitarspysigproc import (
10+
pfb_decompose,
11+
pfb_reconstruct,
12+
kaiser_coeffs,
13+
kaiser_syn_coeffs,
14+
npr_analysis,
15+
npr_synthesis,
16+
rref_coef,
17+
)
18+
import matplotlib.pyplot as plt
19+
20+
21+
def create_chirp(t_len, fs, bw, pad, nchans, nslice):
22+
"""Creates a chirp signal
23+
24+
Parameters
25+
----------
26+
t_len : float
27+
Length of chirp in seconds
28+
fs : float
29+
Sampling frequency in Hz
30+
bw : float
31+
Bandwidth of chirp
32+
nzeros : tuple
33+
Number of zeros to pad in the begining and end of the array.
34+
nchans : int
35+
Number of channels for the PFB
36+
nslice : int
37+
Number of time samples from the pfb
38+
39+
Returns
40+
-------
41+
tout : ndarray
42+
The time vector for the created signal
43+
xout : ndarray
44+
Created signal
45+
"""
46+
nar = (
47+
np.arange(int(-nslice * nchans / 2), int(nslice * nchans / 2), dtype=float)
48+
/ nslice
49+
/ nchans
50+
)
51+
t = np.linspace(-t_len / 2, t_len / 2, int(t_len * fs))
52+
dphi = 2 * np.pi * nar * bw / fs
53+
phi = np.mod(np.cumsum(dphi), 2 * np.pi)
54+
x = np.exp(-1j * phi)
55+
# x = sig.chirp(t,t1=t_len,f0=0,f1=bw,method='linear')
56+
57+
xout = np.concatenate((pad[0], x, pad[1]), axis=0)
58+
tp1 = -1 * np.arange(0, len(pad[0]), dtype=float)[::-1] / fs - t_len / 2
59+
tp2 = np.arange(0, len(pad[1]), dtype=float) / fs + t_len / 2
60+
tout = np.concatenate((tp1, t, tp2), axis=0)
61+
62+
return tout, xout
63+
64+
65+
def runchirptest(t_len, fs, bw, nzeros, nchans, nslice):
66+
"""Creates a chirp and runs the standard PFB analysis and reconstruction
67+
68+
Parameters
69+
----------
70+
t_len : float
71+
Length of chirp in seconds
72+
fs : float
73+
Sampling frequency in Hz
74+
bw : float
75+
Bandwidth of chirp
76+
nzeros : int
77+
Number of zeros to pad
78+
nchans : int
79+
Number of channels for the PFB
80+
nslice : int
81+
Number of time samples from the pfb
82+
83+
Returns
84+
-------
85+
x_rec : ndarray
86+
Reconstructed signal
87+
tin : ndarray
88+
The time vector for the input signal
89+
x : ndarray
90+
Input signal
91+
x_pfb : ndarray
92+
The result of the PFB analysis in an nchans x slice array.
93+
"""
94+
pad = [np.zeros(nzeros), np.zeros(nzeros)]
95+
t, x = create_chirp(t_len, fs, bw, pad, nchans, nslice)
96+
97+
coeffs = kaiser_coeffs(nchans, 8.0)
98+
mask = np.ones(nchans, dtype=bool)
99+
xout = pfb_decompose(x, nchans, coeffs, mask)
100+
fillmethod = ""
101+
fillparams = [0, 0]
102+
syn_coeffs = kaiser_syn_coeffs(nchans, 8)
103+
x_rec = pfb_reconstruct(
104+
xout, nchans, syn_coeffs, mask, fillmethod, fillparams=[], realout=False
105+
)
106+
return x_rec, t, x, xout
107+
108+
109+
def runnprchirptest(t_len, fs, bw, nzeros, nchans, nslice, ntaps=64):
110+
"""Creates a chirp and runs the near perfect PFB analysis and reconstruction
111+
112+
Parameters
113+
----------
114+
t_len : float
115+
Length of chirp in seconds
116+
fs : float
117+
Sampling frequency in Hz
118+
bw : float
119+
Bandwidth of chirp
120+
nchans : int
121+
Number of channels for the PFB
122+
nslice : int
123+
Number of time samples from the pfb
124+
125+
Returns
126+
-------
127+
x_rec : ndarray
128+
Reconstructed signal
129+
tin : ndarray
130+
The time vector for the input signal
131+
x : ndarray
132+
Input signal
133+
x_pfb : ndarray
134+
The result of the PFB analysis in an nchans x slice array.
135+
"""
136+
pad = [np.zeros(nzeros), np.zeros(nzeros)]
137+
t, x = create_chirp(t_len, fs, bw, pad, nchans, nslice)
138+
coeffs = rref_coef(nchans, ntaps)
139+
mask = np.ones(nchans, dtype=bool)
140+
xout = npr_analysis(x, nchans, coeffs)
141+
fillmethod = ""
142+
fillparams = [0, 0]
143+
x_rec = npr_synthesis(xout, nchans, coeffs)
144+
return x_rec, t, x, xout
145+
146+
147+
def nexpow2(x):
148+
"""Returns the next power of two.
149+
150+
Parameters
151+
----------
152+
x : int
153+
Inital number.
154+
155+
Returns
156+
-------
157+
int
158+
The next power of two of x.
159+
"""
160+
161+
return int(np.power(2, np.ceil(np.log2(x))))
162+
163+
164+
def plotdata(x, x_rec, tin, tout, g_del=0):
165+
"""Plot the data and return the figure.
166+
167+
Parameters
168+
----------
169+
x : ndarray
170+
Input signal
171+
x_rec : ndarray
172+
Reconstructed signal
173+
tin : ndarray
174+
The time vector for the input signal
175+
tout : ndarray
176+
The time vector for the output signal
177+
178+
Returns
179+
-------
180+
fig : matplotlib.fig
181+
The matplotlib fig for plotting or saving.
182+
"""
183+
184+
fig, ax = plt.subplots(3, 1, figsize=(10, 5))
185+
186+
inlen = x.shape[0]
187+
outlen = x_rec.shape[0]
188+
tau = tin[1] - tin[0]
189+
190+
ax[0].plot(tin, x.real, label="Input")
191+
ax[0].plot(tout, np.roll(x_rec.real, -g_del), label="Output")
192+
193+
ax[0].set_xlabel("Time in Seconds")
194+
ax[0].set_ylabel("Amplitude")
195+
ax[0].set_title("Time Domain Real Part")
196+
ax[0].grid(True)
197+
198+
ax[1].plot(tin, x.imag, label="Input")
199+
ax[1].plot(tout, np.roll(x_rec.imag, -g_del), label="Output")
200+
201+
ax[1].set_xlabel("Time in Seconds")
202+
ax[1].set_ylabel("Amplitude")
203+
ax[1].set_title("Time Domain Imaginary Part")
204+
ax[1].grid(True)
205+
nfft_in = nexpow2(inlen)
206+
nfft_out = nexpow2(outlen)
207+
208+
in_freq = np.fft.fftshift(np.fft.fftfreq(nfft_in, d=tau))
209+
out_freq = np.fft.fftshift(np.fft.fftfreq(nfft_out, d=tau))
210+
211+
spec_in = np.abs(np.fft.fftshift(np.fft.fft(x, n=nfft_in))) ** 2
212+
spec_out = np.abs(np.fft.fftshift(np.fft.fft(x_rec[:, 0], n=nfft_out))) ** 2
213+
214+
spec_in_log = 10 * np.log10(spec_in)
215+
spec_out_log = 10 * np.log10(spec_out)
216+
217+
ax[2].plot(in_freq, spec_in_log, label="Input")
218+
ax[2].plot(out_freq, spec_out_log, label="Output")
219+
220+
ax[2].set_xlabel("Frequency in Hz")
221+
ax[2].set_ylabel("Amp dB")
222+
ax[2].set_title("Frequency Content")
223+
ax[2].grid(True)
224+
ax[2].set_ylim([0, 60])
225+
fig.tight_layout()
226+
return fig
227+
228+
229+
def plot_spectrogram(x, x_rec, x_pfb):
230+
"""Plots the input signal, pfb output and reconstructed signal.
231+
232+
Parameters
233+
----------
234+
x : ndarray
235+
Input signal
236+
x_rec : ndarray
237+
Reconstructed signal
238+
x_pfb : ndarray
239+
The result of the PFB analysis in an nchans x slice array.
240+
241+
Returns
242+
-------
243+
fig : matplotlib.fig
244+
The matplotlib fig for plotting or saving.
245+
"""
246+
fig, ax = plt.subplots(1, 3, figsize=(12, 3.5))
247+
nfft = 256
248+
w = sig.get_window("blackman", nfft)
249+
SFT = sig.ShortTimeFFT(
250+
w, hop=nfft, fs=10000, mfft=nfft, scale_to="magnitude", fft_mode="centered"
251+
)
252+
253+
sxin = 20 * np.log10(np.abs((SFT.stft(x))) + 1e-12)
254+
sxout = 20 * np.log10(np.abs((SFT.stft(x_rec))) + 1e-12)
255+
256+
im1 = ax[0].imshow(
257+
sxin[::-1],
258+
origin="lower",
259+
aspect="auto",
260+
extent=SFT.extent(len(x)),
261+
cmap="viridis",
262+
vmin=-50,
263+
vmax=0,
264+
)
265+
im2 = ax[1].imshow(
266+
sxout[::-1],
267+
origin="lower",
268+
aspect="auto",
269+
extent=SFT.extent(len(x_rec)),
270+
cmap="viridis",
271+
vmin=-50,
272+
vmax=0,
273+
)
274+
275+
nchan, nslice = x_pfb.shape
276+
x_pfb = np.fft.fftshift(x_pfb / np.abs(x_pfb.flatten()).max(), axes=0)
277+
x_pfb_db = 20 * np.log10(np.abs(x_pfb) + 1e-12)
278+
im3 = ax[2].imshow(
279+
x_pfb_db,
280+
origin="lower",
281+
aspect="auto",
282+
extent=SFT.extent(len(x_rec)),
283+
cmap="viridis",
284+
vmin=-50,
285+
vmax=0,
286+
)
287+
fig.tight_layout()
288+
fig.subplots_adjust(right=0.8)
289+
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
290+
fig.colorbar(im3, cax=cbar_ax)
291+
292+
return fig
293+
294+
def runexample():
295+
"""Function for running each of the examples."""
296+
nchans = 64
297+
nslice = 2048
298+
fs = 10000
299+
t_len = nchans * nslice / fs
300+
bw = 2000
301+
ntaps = 64
302+
g_del = nchans * (ntaps - 1) // 2
303+
nzeros = 2048
304+
305+
x_rec, t, x, xpfb = runchirptest(t_len, fs, bw, nzeros*2, nchans, nslice)
306+
307+
fig = plotdata(x, x_rec[:t.shape[0],:], t, t, nchans*ntaps)
308+
fig.savefig("chirpdata.png")
309+
plt.close(fig)
310+
fig2 = plot_spectrogram(x, x_rec[:, 0], xpfb[:, :, 0])
311+
fig2.savefig("chirpspecgrams.png")
312+
plt.close(fig2)
313+
314+
x_rec, t, x, xpfb = runnprchirptest(t_len, fs, bw, nzeros, nchans, nslice, ntaps)
315+
x_rec = x_rec[: len(x), np.newaxis] # need to add new axis due to plotting issue
316+
317+
fig = plotdata(x, x_rec, t, t, g_del)
318+
fig.savefig("chirpdatanpr.png")
319+
plt.close(fig)
320+
321+
fig2 = plot_spectrogram(x, x_rec[:, 0], xpfb)
322+
fig2.savefig("nprchirpspecgrams.png")
323+
plt.close(fig2)
324+
325+
326+
if __name__ == "__main__":
327+
runexample()

0 commit comments

Comments
 (0)