Skip to content

Commit 3619acd

Browse files
committed
Refactor Examples: Extract Plotting Code and Convert Print Statements to Comments #34
feat: Refactor plotting code and improve example readability - Created a new utilities module for reusable plotting functions in . - Refactored 11 example files to utilize the new plotting utilities, enhancing code organization and readability. - Converted print statements to structured comments for better documentation. - Removed conditionals, streamlining plotting integration. - Maintained full functionality of visualizations while improving code clarity and separation of concerns. - Established a consistent commenting structure across examples. - Documented progress and future enhancements for ongoing refactoring efforts.
1 parent 4ac2827 commit 3619acd

19 files changed

+3687
-1831
lines changed

examples/channels/plot_awgn_channel.py

Lines changed: 39 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,17 @@
1515
# -------------------------------
1616
# We start by importing the necessary modules and setting up the environment.
1717

18-
import matplotlib.pyplot as plt
1918
import numpy as np
2019
import torch
2120

21+
from examples.utils.plotting import (
22+
setup_plotting_style,
23+
plot_signal_noise_comparison,
24+
plot_snr_psnr_comparison,
25+
plot_snr_vs_mse,
26+
plot_noise_level_analysis
27+
)
28+
2229
from kaira.channels import AWGNChannel
2330
from kaira.metrics.image import PSNR
2431
from kaira.metrics.signal import SNR
@@ -28,6 +35,9 @@
2835
torch.manual_seed(42)
2936
np.random.seed(42)
3037

38+
# Configure plotting style
39+
setup_plotting_style()
40+
3141
# %%
3242
# Create Sample Signal
3343
# ------------------------------------
@@ -93,69 +103,36 @@
93103
measured_metrics.append({"target_snr_db": snr_db, "measured_snr_db": measured_snr, "measured_psnr_db": measured_psnr})
94104

95105
# Ensure we're using float values for string formatting
96-
print(f"Target SNR: {snr_db:.1f} dB, Measured SNR: {measured_snr:.1f} dB, PSNR: {measured_psnr:.1f} dB")
106+
# AWGN Channel Performance Analysis
107+
# ===============================
108+
# Target SNR: {snr_db:.1f} dB, Measured SNR: {measured_snr:.1f} dB, PSNR: {measured_psnr:.1f} dB
97109

98110
# %%
99111
# Visualize the Results
100112
# -------------------------------------
101113
# Let's visualize how different SNR levels affect the transmitted signal.
102114

103-
plt.figure(figsize=(10, 8))
104-
105-
# Plot the original signal
106-
plt.subplot(len(snr_levels_db) + 1, 1, 1)
107-
plt.plot(t, signal, "b-", linewidth=1.5)
108-
plt.title("Original Signal")
109-
plt.grid(True)
110-
plt.ylabel("Amplitude")
111-
plt.xlim([0, 1])
112-
113-
# Plot each noisy signal
114-
for i, (snr_db, output) in enumerate(outputs):
115-
plt.subplot(len(snr_levels_db) + 1, 1, i + 2)
116-
plt.plot(t, output, "r-", alpha=0.8)
117-
measured_snr = measured_metrics[i]["measured_snr_db"]
118-
plt.title(f"AWGN Channel (Target SNR = {snr_db} dB, Measured SNR = {measured_snr:.1f} dB)")
119-
plt.grid(True)
120-
plt.ylabel("Amplitude")
121-
if i == len(outputs) - 1:
122-
plt.xlabel("Time (s)")
123-
plt.xlim([0, 1])
124-
125-
plt.tight_layout()
126-
plt.show()
115+
# Visualization: Signal Degradation with Noise
116+
# ============================================
117+
# Compare the original clean signal with signals processed through
118+
# AWGN channels at different SNR levels to observe noise effects.
119+
120+
fig = plot_signal_noise_comparison(t, signal, outputs, measured_metrics,
121+
"AWGN Channel Effects on Signal Transmission")
122+
fig.show()
127123

128124
# %%
129125
# Compare Theoretical and Measured SNR Values
130126
# ------------------------------------------------------------------------------
131127
# Let's compare the target SNR values with what we actually measured.
132128

133-
plt.figure(figsize=(10, 5))
134-
135-
target_snrs = [metric["target_snr_db"] for metric in measured_metrics]
136-
measured_snrs = [metric["measured_snr_db"] for metric in measured_metrics]
137-
measured_psnrs = [metric["measured_psnr_db"] for metric in measured_metrics]
138-
139-
# Plot SNR comparison
140-
plt.subplot(1, 2, 1)
141-
plt.plot(target_snrs, measured_snrs, "bo-", linewidth=2, label="Measured SNR")
142-
plt.plot(target_snrs, target_snrs, "k--", linewidth=1, label="Theoretical (Target)")
143-
plt.grid(True)
144-
plt.xlabel("Target SNR (dB)")
145-
plt.ylabel("Measured SNR (dB)")
146-
plt.title("Theoretical vs. Measured SNR")
147-
plt.legend()
148-
149-
# Plot PSNR values
150-
plt.subplot(1, 2, 2)
151-
plt.plot(target_snrs, measured_psnrs, "ro-", linewidth=2)
152-
plt.grid(True)
153-
plt.xlabel("Target SNR (dB)")
154-
plt.ylabel("PSNR (dB)")
155-
plt.title("PSNR vs. Target SNR")
156-
157-
plt.tight_layout()
158-
plt.show()
129+
# SNR and PSNR Analysis
130+
# ====================
131+
# Compare theoretical vs measured SNR values and examine PSNR behavior
132+
# to validate the AWGN channel model performance.
133+
134+
fig = plot_snr_psnr_comparison(measured_metrics, "SNR and PSNR Validation")
135+
fig.show()
159136

160137
# %%
161138
# Calculate Mean Squared Error (MSE)
@@ -166,32 +143,26 @@
166143
for snr_db, output in outputs:
167144
mse = np.mean((signal - output) ** 2)
168145
mse_values.append((snr_db, mse))
169-
print(f"SNR: {snr_db} dB, MSE: {mse:.6f}")
146+
# MSE Analysis Results
147+
# ===================
148+
# SNR: {snr_db} dB, MSE: {mse:.6f}
170149

171150
# %%
172151
# Plot SNR vs MSE
173152
# -------------------------
174153
# Let's plot the relationship between SNR and MSE.
175154

176-
plt.figure(figsize=(8, 5))
155+
# MSE vs SNR Relationship Analysis
156+
# ===============================
157+
# Examine the theoretical and measured relationship between
158+
# signal-to-noise ratio and mean squared error.
159+
177160
snr_levels = [snr for snr, _ in mse_values]
178161
mse_vals = [mse for _, mse in mse_values]
179-
180-
plt.plot(snr_levels, mse_vals, "o-", linewidth=2)
181-
plt.grid(True)
182-
plt.xlabel("SNR (dB)")
183-
plt.ylabel("Mean Squared Error")
184-
plt.title("SNR vs. Mean Squared Error")
185-
plt.yscale("log") # Use logarithmic scale for MSE
186-
187-
# Add theoretical MSE curve: MSE = noise_power = signal_power / 10^(SNR/10)
188-
snr_range = np.linspace(-6, 21, 100)
189162
signal_power = amplitude**2 / 2
190-
theoretical_mse = signal_power / np.power(10, snr_range / 10)
191-
plt.plot(snr_range, theoretical_mse, "k--", linewidth=1, label="Theoretical")
192-
plt.legend()
193163

194-
plt.show()
164+
fig = plot_snr_vs_mse(snr_levels, mse_vals, signal_power, "SNR vs Mean Squared Error Analysis")
165+
fig.show()
195166

196167
# %%
197168
# Conclusion

0 commit comments

Comments
 (0)