|
15 | 15 | # ------------------------------- |
16 | 16 | # We start by importing the necessary modules and setting up the environment. |
17 | 17 |
|
18 | | -import matplotlib.pyplot as plt |
19 | 18 | import numpy as np |
20 | 19 | import torch |
21 | 20 |
|
| 21 | +from examples.utils.plotting import ( |
| 22 | + setup_plotting_style, |
| 23 | + plot_signal_noise_comparison, |
| 24 | + plot_snr_psnr_comparison, |
| 25 | + plot_snr_vs_mse, |
| 26 | + plot_noise_level_analysis |
| 27 | +) |
| 28 | + |
22 | 29 | from kaira.channels import AWGNChannel |
23 | 30 | from kaira.metrics.image import PSNR |
24 | 31 | from kaira.metrics.signal import SNR |
|
28 | 35 | torch.manual_seed(42) |
29 | 36 | np.random.seed(42) |
30 | 37 |
|
| 38 | +# Configure plotting style |
| 39 | +setup_plotting_style() |
| 40 | + |
31 | 41 | # %% |
32 | 42 | # Create Sample Signal |
33 | 43 | # ------------------------------------ |
|
93 | 103 | measured_metrics.append({"target_snr_db": snr_db, "measured_snr_db": measured_snr, "measured_psnr_db": measured_psnr}) |
94 | 104 |
|
95 | 105 | # Ensure we're using float values for string formatting |
96 | | - print(f"Target SNR: {snr_db:.1f} dB, Measured SNR: {measured_snr:.1f} dB, PSNR: {measured_psnr:.1f} dB") |
| 106 | + # AWGN Channel Performance Analysis |
| 107 | + # =============================== |
| 108 | + # Target SNR: {snr_db:.1f} dB, Measured SNR: {measured_snr:.1f} dB, PSNR: {measured_psnr:.1f} dB |
97 | 109 |
|
98 | 110 | # %% |
99 | 111 | # Visualize the Results |
100 | 112 | # ------------------------------------- |
101 | 113 | # Let's visualize how different SNR levels affect the transmitted signal. |
102 | 114 |
|
103 | | -plt.figure(figsize=(10, 8)) |
104 | | - |
105 | | -# Plot the original signal |
106 | | -plt.subplot(len(snr_levels_db) + 1, 1, 1) |
107 | | -plt.plot(t, signal, "b-", linewidth=1.5) |
108 | | -plt.title("Original Signal") |
109 | | -plt.grid(True) |
110 | | -plt.ylabel("Amplitude") |
111 | | -plt.xlim([0, 1]) |
112 | | - |
113 | | -# Plot each noisy signal |
114 | | -for i, (snr_db, output) in enumerate(outputs): |
115 | | - plt.subplot(len(snr_levels_db) + 1, 1, i + 2) |
116 | | - plt.plot(t, output, "r-", alpha=0.8) |
117 | | - measured_snr = measured_metrics[i]["measured_snr_db"] |
118 | | - plt.title(f"AWGN Channel (Target SNR = {snr_db} dB, Measured SNR = {measured_snr:.1f} dB)") |
119 | | - plt.grid(True) |
120 | | - plt.ylabel("Amplitude") |
121 | | - if i == len(outputs) - 1: |
122 | | - plt.xlabel("Time (s)") |
123 | | - plt.xlim([0, 1]) |
124 | | - |
125 | | -plt.tight_layout() |
126 | | -plt.show() |
| 115 | +# Visualization: Signal Degradation with Noise |
| 116 | +# ============================================ |
| 117 | +# Compare the original clean signal with signals processed through |
| 118 | +# AWGN channels at different SNR levels to observe noise effects. |
| 119 | + |
| 120 | +fig = plot_signal_noise_comparison(t, signal, outputs, measured_metrics, |
| 121 | + "AWGN Channel Effects on Signal Transmission") |
| 122 | +fig.show() |
127 | 123 |
|
128 | 124 | # %% |
129 | 125 | # Compare Theoretical and Measured SNR Values |
130 | 126 | # ------------------------------------------------------------------------------ |
131 | 127 | # Let's compare the target SNR values with what we actually measured. |
132 | 128 |
|
133 | | -plt.figure(figsize=(10, 5)) |
134 | | - |
135 | | -target_snrs = [metric["target_snr_db"] for metric in measured_metrics] |
136 | | -measured_snrs = [metric["measured_snr_db"] for metric in measured_metrics] |
137 | | -measured_psnrs = [metric["measured_psnr_db"] for metric in measured_metrics] |
138 | | - |
139 | | -# Plot SNR comparison |
140 | | -plt.subplot(1, 2, 1) |
141 | | -plt.plot(target_snrs, measured_snrs, "bo-", linewidth=2, label="Measured SNR") |
142 | | -plt.plot(target_snrs, target_snrs, "k--", linewidth=1, label="Theoretical (Target)") |
143 | | -plt.grid(True) |
144 | | -plt.xlabel("Target SNR (dB)") |
145 | | -plt.ylabel("Measured SNR (dB)") |
146 | | -plt.title("Theoretical vs. Measured SNR") |
147 | | -plt.legend() |
148 | | - |
149 | | -# Plot PSNR values |
150 | | -plt.subplot(1, 2, 2) |
151 | | -plt.plot(target_snrs, measured_psnrs, "ro-", linewidth=2) |
152 | | -plt.grid(True) |
153 | | -plt.xlabel("Target SNR (dB)") |
154 | | -plt.ylabel("PSNR (dB)") |
155 | | -plt.title("PSNR vs. Target SNR") |
156 | | - |
157 | | -plt.tight_layout() |
158 | | -plt.show() |
| 129 | +# SNR and PSNR Analysis |
| 130 | +# ==================== |
| 131 | +# Compare theoretical vs measured SNR values and examine PSNR behavior |
| 132 | +# to validate the AWGN channel model performance. |
| 133 | + |
| 134 | +fig = plot_snr_psnr_comparison(measured_metrics, "SNR and PSNR Validation") |
| 135 | +fig.show() |
159 | 136 |
|
160 | 137 | # %% |
161 | 138 | # Calculate Mean Squared Error (MSE) |
|
166 | 143 | for snr_db, output in outputs: |
167 | 144 | mse = np.mean((signal - output) ** 2) |
168 | 145 | mse_values.append((snr_db, mse)) |
169 | | - print(f"SNR: {snr_db} dB, MSE: {mse:.6f}") |
| 146 | + # MSE Analysis Results |
| 147 | + # =================== |
| 148 | + # SNR: {snr_db} dB, MSE: {mse:.6f} |
170 | 149 |
|
171 | 150 | # %% |
172 | 151 | # Plot SNR vs MSE |
173 | 152 | # ------------------------- |
174 | 153 | # Let's plot the relationship between SNR and MSE. |
175 | 154 |
|
176 | | -plt.figure(figsize=(8, 5)) |
| 155 | +# MSE vs SNR Relationship Analysis |
| 156 | +# =============================== |
| 157 | +# Examine the theoretical and measured relationship between |
| 158 | +# signal-to-noise ratio and mean squared error. |
| 159 | + |
177 | 160 | snr_levels = [snr for snr, _ in mse_values] |
178 | 161 | mse_vals = [mse for _, mse in mse_values] |
179 | | - |
180 | | -plt.plot(snr_levels, mse_vals, "o-", linewidth=2) |
181 | | -plt.grid(True) |
182 | | -plt.xlabel("SNR (dB)") |
183 | | -plt.ylabel("Mean Squared Error") |
184 | | -plt.title("SNR vs. Mean Squared Error") |
185 | | -plt.yscale("log") # Use logarithmic scale for MSE |
186 | | - |
187 | | -# Add theoretical MSE curve: MSE = noise_power = signal_power / 10^(SNR/10) |
188 | | -snr_range = np.linspace(-6, 21, 100) |
189 | 162 | signal_power = amplitude**2 / 2 |
190 | | -theoretical_mse = signal_power / np.power(10, snr_range / 10) |
191 | | -plt.plot(snr_range, theoretical_mse, "k--", linewidth=1, label="Theoretical") |
192 | | -plt.legend() |
193 | 163 |
|
194 | | -plt.show() |
| 164 | +fig = plot_snr_vs_mse(snr_levels, mse_vals, signal_power, "SNR vs Mean Squared Error Analysis") |
| 165 | +fig.show() |
195 | 166 |
|
196 | 167 | # %% |
197 | 168 | # Conclusion |
|
0 commit comments