-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_fft.py
More file actions
284 lines (219 loc) · 10.5 KB
/
test_fft.py
File metadata and controls
284 lines (219 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import unittest
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
from pathlib import Path
class TestFFTAnalysis(unittest.TestCase):
"""Test suite for FFT analysis functions."""
def setUp(self):
"""Set up test fixtures."""
self.fs = 1000 # Sampling frequency
self.duration = 1 # 1 second
self.t = np.arange(0, self.duration, 1/self.fs)
self.frequencies = [5, 10, 15, 20]
def tearDown(self):
"""Clean up after tests."""
plt.close('all')
if os.path.exists('test_plot.png'):
os.remove('test_plot.png')
def test_time_vector_generation(self):
"""Test that time vector is correctly generated."""
expected_length = self.fs * self.duration
self.assertEqual(len(self.t), expected_length)
self.assertAlmostEqual(self.t[0], 0.0)
self.assertAlmostEqual(self.t[-1], self.duration - 1/self.fs)
def test_sine_wave_generation(self):
"""Test sine wave generation at specific frequency."""
f = 5 # 5 Hz
signal = np.sin(2 * np.pi * f * self.t)
# Check signal properties
self.assertEqual(len(signal), len(self.t))
self.assertTrue(np.all(np.abs(signal) <= 1.0)) # Amplitude should be 1
self.assertTrue(np.isfinite(signal).all()) # No NaN or Inf
def test_fft_computation(self):
"""Test FFT computation produces correct results."""
f = 10 # 10 Hz
signal = np.sin(2 * np.pi * f * self.t)
ft_signal = np.fft.fft(signal)
# Check FFT properties
self.assertEqual(len(ft_signal), len(signal))
self.assertTrue(np.isfinite(ft_signal).all()) # No NaN or Inf
# Check magnitude is symmetric for real signal
magnitude = np.abs(ft_signal)
self.assertGreater(magnitude.max(), 0)
def test_fft_peak_detection(self):
"""Test that FFT correctly identifies the dominant frequency."""
f = 15 # 15 Hz
signal = np.sin(2 * np.pi * f * self.t)
ft_signal = np.fft.fft(signal)
freq = np.fft.fftfreq(len(signal), d=1/self.fs)
# Find peak in positive frequencies only
positive_freqs = freq[:len(freq)//2]
magnitude = np.abs(ft_signal[:len(ft_signal)//2])
peak_idx = np.argmax(magnitude)
peak_freq = positive_freqs[peak_idx]
# Peak should be close to the input frequency
self.assertAlmostEqual(peak_freq, f, delta=1)
def test_frequency_vector_generation(self):
"""Test that frequency vector is correctly generated."""
signal = np.sin(2 * np.pi * 5 * self.t)
freq = np.fft.fftfreq(len(signal), d=1/self.fs)
# Check frequency vector properties
self.assertEqual(len(freq), len(signal))
self.assertAlmostEqual(freq[0], 0.0)
self.assertAlmostEqual(freq[1], 1/self.duration, places=2)
def test_multiple_frequencies_fft(self):
"""Test FFT on multiple frequency signals."""
for f in self.frequencies:
signal = np.sin(2 * np.pi * f * self.t)
ft_signal = np.fft.fft(signal)
freq = np.fft.fftfreq(len(signal), d=1/self.fs)
# Check each frequency has a valid FFT
self.assertEqual(len(ft_signal), len(signal))
self.assertTrue(np.isfinite(ft_signal).all())
self.assertEqual(len(freq), len(signal))
def test_plot_generation(self):
"""Test that plots can be generated without errors."""
fig, axes = plt.subplots(len(self.frequencies), 2, figsize=(12, 10))
for i, f in enumerate(self.frequencies):
signal = np.sin(2 * np.pi * f * self.t)
ft_signal = np.fft.fft(signal)
freq = np.fft.fftfreq(len(signal), d=1/self.fs)
# Plot signal
axes[i, 0].plot(self.t, signal, linewidth=1.5)
axes[i, 0].set_title(f'Signal: {f} Hz', fontweight='bold')
# Plot FFT
axes[i, 1].plot(freq, np.abs(ft_signal), linewidth=1.5)
axes[i, 1].set_title(f'FFT: Peak at {f} Hz', fontweight='bold')
self.assertEqual(len(fig.axes), len(self.frequencies) * 2)
def test_plot_save(self):
"""Test that plots can be saved to file."""
fig, axes = plt.subplots(1, 1, figsize=(8, 6))
f = 10
signal = np.sin(2 * np.pi * f * self.t)
axes.plot(self.t, signal)
output_file = 'test_plot.png'
fig.savefig(output_file, dpi=150, bbox_inches='tight')
# Check file was created
self.assertTrue(os.path.exists(output_file))
self.assertGreater(os.path.getsize(output_file), 0)
def test_signal_properties(self):
"""Test basic properties of generated signals."""
for f in self.frequencies:
signal = np.sin(2 * np.pi * f * self.t)
# Amplitude should be approximately 1
self.assertAlmostEqual(np.max(signal), 1.0, places=1)
self.assertAlmostEqual(np.min(signal), -1.0, places=1)
# Mean should be approximately 0
self.assertAlmostEqual(np.mean(signal), 0.0, places=1)
def test_nyquist_frequency(self):
"""Test Nyquist frequency consideration."""
nyquist = self.fs / 2
# Frequencies should be well below Nyquist
for f in self.frequencies:
self.assertLess(f, nyquist)
def test_fft_parseval_theorem(self):
"""Test Parseval's theorem: energy in time domain equals energy in frequency domain."""
f = 10
signal = np.sin(2 * np.pi * f * self.t)
# Energy in time domain
time_energy = np.sum(signal ** 2)
# Energy in frequency domain
ft_signal = np.fft.fft(signal)
freq_energy = np.sum(np.abs(ft_signal) ** 2) / len(signal)
# They should be approximately equal
self.assertAlmostEqual(time_energy, freq_energy, delta=10)
def test_backend_detection(self):
"""Test matplotlib backend detection logic."""
# When CI or no DISPLAY is set, should use Agg
use_non_interactive = os.environ.get('CI') or not os.environ.get('DISPLAY')
# This is just checking the logic works, not the actual backend
self.assertIsInstance(use_non_interactive, bool)
class TestFFTEdgeCases(unittest.TestCase):
"""Test edge cases and error handling."""
def setUp(self):
"""Set up test fixtures."""
self.fs = 1000
self.duration = 1
self.t = np.arange(0, self.duration, 1/self.fs)
def tearDown(self):
"""Clean up after tests."""
plt.close('all')
def test_zero_signal_fft(self):
"""Test FFT on zero signal."""
signal = np.zeros(len(self.t))
ft_signal = np.fft.fft(signal)
# FFT of zero signal should be zero
self.assertTrue(np.allclose(ft_signal, 0))
def test_constant_signal_fft(self):
"""Test FFT on constant signal."""
signal = np.ones(len(self.t)) * 5 # Constant value of 5
ft_signal = np.fft.fft(signal)
# DC component (frequency 0) should be present
self.assertGreater(np.abs(ft_signal[0]), 0)
def test_high_frequency_signal(self):
"""Test signal with frequency close to Nyquist."""
nyquist = self.fs / 2
f = nyquist - 10
signal = np.sin(2 * np.pi * f * self.t)
ft_signal = np.fft.fft(signal)
# Should still compute without error
self.assertEqual(len(ft_signal), len(signal))
self.assertTrue(np.isfinite(ft_signal).all())
def test_empty_frequency_list(self):
"""Test behavior with empty frequency list."""
frequencies = []
fig, axes = plt.subplots(len(frequencies) or 1, 2, figsize=(12, 10))
if len(frequencies) == 0:
# With no frequencies, should create a single subplot
self.assertEqual(len(fig.axes), 2)
class TestFFTIntegration(unittest.TestCase):
"""Integration tests for the complete FFT workflow."""
def setUp(self):
"""Set up test fixtures."""
self.fs = 1000
self.duration = 1
self.t = np.arange(0, self.duration, 1/self.fs)
self.frequencies = [5, 10, 15, 20]
self.colors = ['blue', 'orange', 'green', 'red']
def tearDown(self):
"""Clean up after tests."""
plt.close('all')
if os.path.exists('integration_test_plot.png'):
os.remove('integration_test_plot.png')
def test_complete_fft_workflow(self):
"""Test the complete FFT analysis workflow."""
fig, axes = plt.subplots(len(self.frequencies), 2, figsize=(12, 10))
fig.suptitle('FFT Analysis of Multiple Sine Waves', fontsize=16, fontweight='bold')
# Process each frequency
for i, (f, color) in enumerate(zip(self.frequencies, self.colors)):
# Create a signal: sine wave
signal = np.sin(2 * np.pi * f * self.t)
# Compute the Fourier transform
ft_signal = np.fft.fft(signal)
freq = np.fft.fftfreq(len(signal), d=1/self.fs)
# Plot the original signal
axes[i, 0].plot(self.t, signal, color=color, linewidth=1.5)
axes[i, 0].set_title(f'Signal: {f} Hz', fontweight='bold')
axes[i, 0].set_xlabel('Time [s]')
axes[i, 0].set_ylabel('Amplitude')
axes[i, 0].grid(True, alpha=0.3)
axes[i, 0].set_xlim(0, 0.5)
# Plot the Fourier transform
axes[i, 1].plot(freq, np.abs(ft_signal), color=color, linewidth=1.5)
axes[i, 1].set_title(f'FFT: Peak at {f} Hz', fontweight='bold')
axes[i, 1].set_xlabel('Frequency [Hz]')
axes[i, 1].set_ylabel('Magnitude')
axes[i, 1].set_xlim(0, 50)
axes[i, 1].grid(True, alpha=0.3)
plt.tight_layout()
# Verify the plot was created successfully
self.assertEqual(len(fig.axes), len(self.frequencies) * 2)
# Save and verify
output_file = 'integration_test_plot.png'
fig.savefig(output_file, dpi=150, bbox_inches='tight')
self.assertTrue(os.path.exists(output_file))
self.assertGreater(os.path.getsize(output_file), 0)
if __name__ == '__main__':
    # Script entry point: run every TestCase defined in this module.
    unittest.main()