Skip to content

Commit d02d456

Browse files
committed
#34 refactor: update plotting imports and utilize PlottingUtils for consistent style setup
1 parent 9b119f4 commit d02d456

File tree

11 files changed

+311
-117
lines changed

11 files changed

+311
-117
lines changed

docs/examples/constraints/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Constraint handling and optimization techniques for communication systems design
2727

2828
.. raw:: html
2929

30-
<div class="sphx-glr-thumbcontainer" tooltip="This example demonstrates how to combine multiple constraints in Kaira to satisfy complex signal requirements. We'll explore the composition utilities and see how constraints can be sequentially applied to meet practical transmission specifications.">
30+
<div class="sphx-glr-thumbcontainer" tooltip="ax.text(0.5, 0.5, 'Spectral Mask Constraint Effects (Visualization placeholder)', ha='center', va='center', transform=ax.transAxes, fontsize=14) ax.set_title('Spectral Mask Constraint Effects', fontsize=16, fontweight='bold') plt.show()=============================================================================================================================================================== This example demonstrates how to combine multiple constraints in Kaira to satisfy complex signal requirements. We'll explore the composition utilities and see how constraints can be sequentially applied to meet practical transmission specifications.">
3131

3232
.. only:: html
3333

examples/channels/plot_awgn_channel.py

Lines changed: 100 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,22 @@
1515
# -------------------------------
1616
# We start by importing the necessary modules and setting up the environment.
1717

18+
import matplotlib.pyplot as plt
1819
import numpy as np
1920
import torch
2021

21-
from examples.example_utils.plotting import (
22-
plot_signal_noise_comparison,
23-
plot_snr_psnr_comparison,
24-
plot_snr_vs_mse,
25-
setup_plotting_style,
26-
)
2722
from kaira.channels import AWGNChannel
2823
from kaira.metrics.image import PSNR
2924
from kaira.metrics.signal import SNR
3025
from kaira.utils import snr_to_noise_power
26+
from kaira.utils.plotting import PlottingUtils
3127

3228
# Set random seed for reproducibility
3329
torch.manual_seed(42)
3430
np.random.seed(42)
3531

3632
# Configure plotting style
37-
setup_plotting_style()
33+
PlottingUtils.setup_plotting_style()
3834

3935
# %%
4036
# Create Sample Signal
@@ -115,7 +111,55 @@
115111
# Compare the original clean signal with signals processed through
116112
# AWGN channels at different SNR levels to observe noise effects.
117113

118-
fig = plot_signal_noise_comparison(t, signal, outputs, measured_metrics, "AWGN Channel Effects on Signal Transmission")
114+
fig, axes = plt.subplots(2, 2, figsize=(15, 10), constrained_layout=True)
115+
fig.suptitle("AWGN Channel Effects on Signal Transmission", fontsize=16, fontweight="bold")
116+
117+
# Plot original and noisy signals
118+
ax1 = axes[0, 0]
119+
ax1.plot(t, signal, "b-", linewidth=2, label="Original Signal", alpha=0.8)
120+
for i, (snr_db, output) in enumerate(outputs[:3]): # Show first 3 for clarity
121+
color = PlottingUtils.MODERN_PALETTE[i % len(PlottingUtils.MODERN_PALETTE)]
122+
ax1.plot(t, output, "--", color=color, linewidth=1.5, alpha=0.7, label=f"SNR: {snr_db} dB")
123+
ax1.set_xlabel("Time")
124+
ax1.set_ylabel("Amplitude")
125+
ax1.set_title("Signal Comparison")
126+
ax1.legend()
127+
ax1.grid(True, alpha=0.3)
128+
129+
# Plot SNR comparison
130+
ax2 = axes[0, 1]
131+
target_snrs = [metric["target_snr_db"] for metric in measured_metrics]
132+
measured_snrs = [metric["measured_snr_db"] for metric in measured_metrics]
133+
ax2.plot(target_snrs, measured_snrs, "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=8)
134+
ax2.plot(target_snrs, target_snrs, "--", color="gray", alpha=0.7, label="Ideal (Target = Measured)")
135+
ax2.set_xlabel("Target SNR (dB)")
136+
ax2.set_ylabel("Measured SNR (dB)")
137+
ax2.set_title("SNR Validation")
138+
ax2.legend()
139+
ax2.grid(True, alpha=0.3)
140+
141+
# Plot PSNR values
142+
ax3 = axes[1, 0]
143+
psnr_values = [metric["measured_psnr_db"] for metric in measured_metrics]
144+
ax3.plot(target_snrs, psnr_values, "s-", color=PlottingUtils.MODERN_PALETTE[1], linewidth=2, markersize=8)
145+
ax3.set_xlabel("Target SNR (dB)")
146+
ax3.set_ylabel("PSNR (dB)")
147+
ax3.set_title("PSNR vs Target SNR")
148+
ax3.grid(True, alpha=0.3)
149+
150+
# Plot noise effects on signal
151+
ax4 = axes[1, 1]
152+
for i, (snr_db, output) in enumerate(outputs):
153+
noise = output - signal
154+
noise_power = np.mean(noise**2)
155+
ax4.bar(i, noise_power, color=PlottingUtils.MODERN_PALETTE[i % len(PlottingUtils.MODERN_PALETTE)], alpha=0.7)
156+
ax4.set_xlabel("Channel Index")
157+
ax4.set_ylabel("Noise Power")
158+
ax4.set_title("Noise Power by Channel")
159+
ax4.set_xticks(range(len(outputs)))
160+
ax4.set_xticklabels([f"{snr}dB" for snr, _ in outputs], rotation=45)
161+
ax4.grid(True, alpha=0.3)
162+
119163
fig.show()
120164

121165
# %%
@@ -128,7 +172,29 @@
128172
# Compare theoretical vs measured SNR values and examine PSNR behavior
129173
# to validate the AWGN channel model performance.
130174

131-
fig = plot_snr_psnr_comparison(measured_metrics, "SNR and PSNR Validation")
175+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), constrained_layout=True)
176+
fig.suptitle("SNR and PSNR Validation", fontsize=16, fontweight="bold")
177+
178+
# SNR comparison
179+
target_snrs = [metric["target_snr_db"] for metric in measured_metrics]
180+
measured_snrs = [metric["measured_snr_db"] for metric in measured_metrics]
181+
psnr_values = [metric["measured_psnr_db"] for metric in measured_metrics]
182+
183+
ax1.scatter(target_snrs, measured_snrs, color=PlottingUtils.MODERN_PALETTE[0], s=100, alpha=0.7, label="Measured SNR")
184+
ax1.plot(target_snrs, target_snrs, "--", color="gray", alpha=0.7, label="Ideal (Target = Measured)")
185+
ax1.set_xlabel("Target SNR (dB)")
186+
ax1.set_ylabel("Measured SNR (dB)")
187+
ax1.set_title("SNR Validation")
188+
ax1.legend()
189+
ax1.grid(True, alpha=0.3)
190+
191+
# PSNR vs Target SNR
192+
ax2.plot(target_snrs, psnr_values, "o-", color=PlottingUtils.MODERN_PALETTE[1], linewidth=2, markersize=8)
193+
ax2.set_xlabel("Target SNR (dB)")
194+
ax2.set_ylabel("PSNR (dB)")
195+
ax2.set_title("PSNR vs Target SNR")
196+
ax2.grid(True, alpha=0.3)
197+
132198
fig.show()
133199

134200
# %%
@@ -158,7 +224,31 @@
158224
mse_vals = [mse for _, mse in mse_values]
159225
signal_power = amplitude**2 / 2
160226

161-
fig = plot_snr_vs_mse(snr_levels, mse_vals, signal_power, "SNR vs Mean Squared Error Analysis")
227+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), constrained_layout=True)
228+
fig.suptitle("SNR vs Mean Squared Error Analysis", fontsize=16, fontweight="bold")
229+
230+
# Plot MSE vs SNR
231+
ax1.semilogy(snr_levels, mse_vals, "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=8)
232+
ax1.set_xlabel("SNR (dB)")
233+
ax1.set_ylabel("MSE")
234+
ax1.set_title("Measured MSE vs SNR")
235+
ax1.grid(True, alpha=0.3)
236+
237+
# Plot theoretical MSE (noise power)
238+
theoretical_mse = []
239+
for snr_db in snr_levels:
240+
snr_linear = 10 ** (snr_db / 10)
241+
theoretical_noise_power = signal_power / snr_linear
242+
theoretical_mse.append(theoretical_noise_power)
243+
244+
ax2.semilogy(snr_levels, mse_vals, "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=8, label="Measured MSE")
245+
ax2.semilogy(snr_levels, theoretical_mse, "--", color=PlottingUtils.MODERN_PALETTE[1], linewidth=2, label="Theoretical MSE")
246+
ax2.set_xlabel("SNR (dB)")
247+
ax2.set_ylabel("MSE")
248+
ax2.set_title("Measured vs Theoretical MSE")
249+
ax2.legend()
250+
ax2.grid(True, alpha=0.3)
251+
162252
fig.show()
163253

164254
# %%

examples/channels/plot_binary_channels.py

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,19 @@
1818
# -------------------------------
1919
# We start by importing the necessary modules and setting up the environment.
2020

21+
import matplotlib.pyplot as plt
2122
import numpy as np
2223
import torch
2324

24-
from examples.example_utils.plotting import (
25-
plot_binary_channel_comparison,
26-
plot_channel_capacity_analysis,
27-
plot_channel_error_rates,
28-
plot_transition_matrices,
29-
setup_plotting_style,
30-
)
3125
from kaira.channels import BinaryErasureChannel, BinarySymmetricChannel, BinaryZChannel
26+
from kaira.utils.plotting import PlottingUtils
3227

3328
# Set random seed for reproducibility
3429
torch.manual_seed(42)
3530
np.random.seed(42)
3631

3732
# Configure plotting style
38-
setup_plotting_style()
33+
PlottingUtils.setup_plotting_style()
3934

4035
# %%
4136
# Generate Binary Data
@@ -167,7 +162,32 @@
167162

168163
# Visualize channel effects
169164
original_data = binary_data[0].numpy()
170-
fig = plot_binary_channel_comparison(original_data, channel_outputs, segment_start, segment_length, "Binary Channel Effects Comparison")
165+
166+
# Create binary channel comparison plot
167+
fig, axes = plt.subplots(2, 2, figsize=(15, 10), constrained_layout=True)
168+
fig.suptitle("Binary Channel Effects Comparison", fontsize=16, fontweight="bold")
169+
170+
# Plot original data segment
171+
segment_end = segment_start + segment_length
172+
ax1 = axes[0, 0]
173+
ax1.plot(range(segment_length), original_data[segment_start:segment_end], "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=6, label="Original")
174+
ax1.set_title("Original Binary Data")
175+
ax1.set_xlabel("Bit Index")
176+
ax1.set_ylabel("Bit Value")
177+
ax1.set_ylim(-0.1, 1.1)
178+
ax1.grid(True, alpha=0.3)
179+
180+
# Plot each channel output
181+
for i, (channel_name, output, error_prob) in enumerate(channel_outputs[:3]):
182+
ax = axes.flat[i + 1]
183+
ax.plot(range(segment_length), output[segment_start:segment_end], "o-", color=PlottingUtils.MODERN_PALETTE[(i + 1) % len(PlottingUtils.MODERN_PALETTE)], linewidth=2, markersize=6, label=f"{channel_name} (p={error_prob:.2f})")
184+
ax.set_title(f"{channel_name} Output")
185+
ax.set_xlabel("Bit Index")
186+
ax.set_ylabel("Bit Value")
187+
ax.set_ylim(-0.1, 1.1)
188+
ax.grid(True, alpha=0.3)
189+
ax.legend()
190+
171191
fig.show()
172192

173193
# %%
@@ -185,15 +205,29 @@
185205
observed_bsc = [err_rate for _, _, err_rate in bsc_outputs]
186206

187207
# Plot BSC error rates
188-
fig1 = plot_channel_error_rates(error_probs, theoretical_bsc, observed_bsc, ["BSC"], "Binary Symmetric Channel Error Rates")
208+
fig1, ax1 = plt.subplots(figsize=(10, 6), constrained_layout=True)
209+
ax1.plot(error_probs, theoretical_bsc, "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=8, label="Theoretical BSC")
210+
ax1.plot(error_probs, observed_bsc, "s-", color=PlottingUtils.MODERN_PALETTE[1], linewidth=2, markersize=8, label="Observed BSC")
211+
ax1.set_xlabel("Error Probability")
212+
ax1.set_ylabel("Error Rate")
213+
ax1.set_title("Binary Symmetric Channel Error Rates", fontsize=14, fontweight="bold")
214+
ax1.legend()
215+
ax1.grid(True, alpha=0.3)
189216
fig1.show()
190217

191218
# Prepare BEC erasure rate data
192219
theoretical_bec = erasure_probs # Theoretical erasure rate equals p
193220
observed_bec = [erasure_rate for _, _, erasure_rate in bec_outputs]
194221

195222
# Plot BEC erasure rates
196-
fig2 = plot_channel_error_rates(erasure_probs, theoretical_bec, observed_bec, ["BEC"], "Binary Erasure Channel Erasure Rates")
223+
fig2, ax2 = plt.subplots(figsize=(10, 6), constrained_layout=True)
224+
ax2.plot(erasure_probs, theoretical_bec, "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=8, label="Theoretical BEC")
225+
ax2.plot(erasure_probs, observed_bec, "s-", color=PlottingUtils.MODERN_PALETTE[1], linewidth=2, markersize=8, label="Observed BEC")
226+
ax2.set_xlabel("Erasure Probability")
227+
ax2.set_ylabel("Erasure Rate")
228+
ax2.set_title("Binary Erasure Channel Erasure Rates", fontsize=14, fontweight="bold")
229+
ax2.legend()
230+
ax2.grid(True, alpha=0.3)
197231
fig2.show()
198232

199233
# Prepare Z-Channel error rate data
@@ -203,7 +237,14 @@
203237
observed_z = [err_rate * p_one for _, _, err_rate in z_outputs]
204238

205239
# Plot Z-Channel error rates
206-
fig3 = plot_channel_error_rates(z_error_probs, theoretical_z, observed_z, ["Z-Channel"], "Z-Channel Error Rates")
240+
fig3, ax3 = plt.subplots(figsize=(10, 6), constrained_layout=True)
241+
ax3.plot(z_error_probs, theoretical_z, "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=8, label="Theoretical Z-Channel")
242+
ax3.plot(z_error_probs, observed_z, "s-", color=PlottingUtils.MODERN_PALETTE[1], linewidth=2, markersize=8, label="Observed Z-Channel")
243+
ax3.set_xlabel("Error Probability")
244+
ax3.set_ylabel("Error Rate")
245+
ax3.set_title("Z-Channel Error Rates", fontsize=14, fontweight="bold")
246+
ax3.legend()
247+
ax3.grid(True, alpha=0.3)
207248
fig3.show()
208249

209250
# %%
@@ -230,7 +271,24 @@
230271
# Plot transition matrices
231272
matrices = [("Binary Symmetric Channel", bsc_matrix, p_bsc), ("Binary Erasure Channel", bec_matrix, p_bec), ("Z-Channel", z_matrix, p_z)]
232273

233-
fig4 = plot_transition_matrices(matrices, "Binary Channel Transition Matrices")
274+
fig4, axes = plt.subplots(1, 3, figsize=(15, 5), constrained_layout=True)
275+
fig4.suptitle("Binary Channel Transition Matrices", fontsize=16, fontweight="bold")
276+
277+
for i, (name, matrix, p) in enumerate(matrices):
278+
ax = axes[i]
279+
im = ax.imshow(matrix, cmap=PlottingUtils.MATRIX_CMAP, interpolation="nearest", aspect="auto")
280+
ax.set_title(f"{name}\n(p={p:.2f})")
281+
ax.set_xlabel("Output")
282+
ax.set_ylabel("Input")
283+
284+
# Add text annotations
285+
for row in range(matrix.shape[0]):
286+
for col in range(matrix.shape[1]):
287+
color = "white" if matrix[row, col] > 0.5 else "black"
288+
ax.text(col, row, f"{matrix[row, col]:.2f}", ha="center", va="center", color=color, fontsize=12, fontweight="bold")
289+
290+
plt.colorbar(im, ax=ax, shrink=0.8)
291+
234292
fig4.show()
235293

236294
# %%
@@ -278,7 +336,7 @@ def calculate_z_capacity(p):
278336
# Plot capacity analysis
279337
capacities = {"BSC": bsc_capacities, "BEC": bec_capacities, "Z-Channel": z_capacities}
280338

281-
fig5 = plot_channel_capacity_analysis(p_range, capacities, "Binary Channel Capacity Analysis")
339+
fig5 = PlottingUtils.plot_capacity_analysis(p_range, capacities, "Binary Channel Capacity Analysis")
282340
fig5.show()
283341

284342
# %%

examples/channels/plot_fading_channels.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@
1717
import torch
1818
from scipy import signal
1919

20-
# Plotting imports
21-
from examples.example_utils.plotting import setup_plotting_style
2220
from kaira.channels import AWGNChannel, FlatFadingChannel, PerfectChannel
2321
from kaira.metrics.signal import BitErrorRate
2422
from kaira.modulations import QPSKModulator
2523
from kaira.modulations.utils import calculate_theoretical_ber
2624
from kaira.utils import snr_to_noise_power
2725

28-
setup_plotting_style()
26+
# Plotting imports
27+
from kaira.utils.plotting import PlottingUtils
28+
29+
PlottingUtils.setup_plotting_style()
2930

3031
# %%
3132
# Imports and Setup

0 commit comments

Comments
 (0)