|
18 | 18 | # ------------------------------- |
19 | 19 | # We start by importing the necessary modules and setting up the environment. |
20 | 20 |
|
| 21 | +import matplotlib.pyplot as plt |
21 | 22 | import numpy as np |
22 | 23 | import torch |
23 | 24 |
|
24 | | -from examples.example_utils.plotting import ( |
25 | | - plot_binary_channel_comparison, |
26 | | - plot_channel_capacity_analysis, |
27 | | - plot_channel_error_rates, |
28 | | - plot_transition_matrices, |
29 | | - setup_plotting_style, |
30 | | -) |
31 | 25 | from kaira.channels import BinaryErasureChannel, BinarySymmetricChannel, BinaryZChannel |
| 26 | +from kaira.utils.plotting import PlottingUtils |
32 | 27 |
|
33 | 28 | # Set random seed for reproducibility |
34 | 29 | torch.manual_seed(42) |
35 | 30 | np.random.seed(42) |
36 | 31 |
|
37 | 32 | # Configure plotting style |
38 | | -setup_plotting_style() |
| 33 | +PlottingUtils.setup_plotting_style() |
39 | 34 |
|
40 | 35 | # %% |
41 | 36 | # Generate Binary Data |
|
167 | 162 |
|
168 | 163 | # Visualize channel effects |
169 | 164 | original_data = binary_data[0].numpy() |
170 | | -fig = plot_binary_channel_comparison(original_data, channel_outputs, segment_start, segment_length, "Binary Channel Effects Comparison") |
| 165 | + |
| 166 | +# Create binary channel comparison plot |
| 167 | +fig, axes = plt.subplots(2, 2, figsize=(15, 10), constrained_layout=True) |
| 168 | +fig.suptitle("Binary Channel Effects Comparison", fontsize=16, fontweight="bold") |
| 169 | + |
| 170 | +# Plot original data segment |
| 171 | +segment_end = segment_start + segment_length |
| 172 | +ax1 = axes[0, 0] |
| 173 | +ax1.plot(range(segment_length), original_data[segment_start:segment_end], "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=6, label="Original") |
| 174 | +ax1.set_title("Original Binary Data") |
| 175 | +ax1.set_xlabel("Bit Index") |
| 176 | +ax1.set_ylabel("Bit Value") |
| 177 | +ax1.set_ylim(-0.1, 1.1) |
| 178 | +ax1.grid(True, alpha=0.3) |
| 179 | + |
| 180 | +# Plot each channel output |
| 181 | +for i, (channel_name, output, error_prob) in enumerate(channel_outputs[:3]): |
| 182 | + ax = axes.flat[i + 1] |
| 183 | + ax.plot(range(segment_length), output[segment_start:segment_end], "o-", color=PlottingUtils.MODERN_PALETTE[(i + 1) % len(PlottingUtils.MODERN_PALETTE)], linewidth=2, markersize=6, label=f"{channel_name} (p={error_prob:.2f})") |
| 184 | + ax.set_title(f"{channel_name} Output") |
| 185 | + ax.set_xlabel("Bit Index") |
| 186 | + ax.set_ylabel("Bit Value") |
| 187 | + ax.set_ylim(-0.1, 1.1) |
| 188 | + ax.grid(True, alpha=0.3) |
| 189 | + ax.legend() |
| 190 | + |
171 | 191 | fig.show() |
172 | 192 |
|
173 | 193 | # %% |
|
185 | 205 | observed_bsc = [err_rate for _, _, err_rate in bsc_outputs] |
186 | 206 |
|
187 | 207 | # Plot BSC error rates |
188 | | -fig1 = plot_channel_error_rates(error_probs, theoretical_bsc, observed_bsc, ["BSC"], "Binary Symmetric Channel Error Rates") |
| 208 | +fig1, ax1 = plt.subplots(figsize=(10, 6), constrained_layout=True) |
| 209 | +ax1.plot(error_probs, theoretical_bsc, "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=8, label="Theoretical BSC") |
| 210 | +ax1.plot(error_probs, observed_bsc, "s-", color=PlottingUtils.MODERN_PALETTE[1], linewidth=2, markersize=8, label="Observed BSC") |
| 211 | +ax1.set_xlabel("Error Probability") |
| 212 | +ax1.set_ylabel("Error Rate") |
| 213 | +ax1.set_title("Binary Symmetric Channel Error Rates", fontsize=14, fontweight="bold") |
| 214 | +ax1.legend() |
| 215 | +ax1.grid(True, alpha=0.3) |
189 | 216 | fig1.show() |
190 | 217 |
|
191 | 218 | # Prepare BEC erasure rate data |
192 | 219 | theoretical_bec = erasure_probs # Theoretical erasure rate equals p |
193 | 220 | observed_bec = [erasure_rate for _, _, erasure_rate in bec_outputs] |
194 | 221 |
|
195 | 222 | # Plot BEC erasure rates |
196 | | -fig2 = plot_channel_error_rates(erasure_probs, theoretical_bec, observed_bec, ["BEC"], "Binary Erasure Channel Erasure Rates") |
| 223 | +fig2, ax2 = plt.subplots(figsize=(10, 6), constrained_layout=True) |
| 224 | +ax2.plot(erasure_probs, theoretical_bec, "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=8, label="Theoretical BEC") |
| 225 | +ax2.plot(erasure_probs, observed_bec, "s-", color=PlottingUtils.MODERN_PALETTE[1], linewidth=2, markersize=8, label="Observed BEC") |
| 226 | +ax2.set_xlabel("Erasure Probability") |
| 227 | +ax2.set_ylabel("Erasure Rate") |
| 228 | +ax2.set_title("Binary Erasure Channel Erasure Rates", fontsize=14, fontweight="bold") |
| 229 | +ax2.legend() |
| 230 | +ax2.grid(True, alpha=0.3) |
197 | 231 | fig2.show() |
198 | 232 |
|
199 | 233 | # Prepare Z-Channel error rate data |
|
203 | 237 | observed_z = [err_rate * p_one for _, _, err_rate in z_outputs] |
204 | 238 |
|
205 | 239 | # Plot Z-Channel error rates |
206 | | -fig3 = plot_channel_error_rates(z_error_probs, theoretical_z, observed_z, ["Z-Channel"], "Z-Channel Error Rates") |
| 240 | +fig3, ax3 = plt.subplots(figsize=(10, 6), constrained_layout=True) |
| 241 | +ax3.plot(z_error_probs, theoretical_z, "o-", color=PlottingUtils.MODERN_PALETTE[0], linewidth=2, markersize=8, label="Theoretical Z-Channel") |
| 242 | +ax3.plot(z_error_probs, observed_z, "s-", color=PlottingUtils.MODERN_PALETTE[1], linewidth=2, markersize=8, label="Observed Z-Channel") |
| 243 | +ax3.set_xlabel("Error Probability") |
| 244 | +ax3.set_ylabel("Error Rate") |
| 245 | +ax3.set_title("Z-Channel Error Rates", fontsize=14, fontweight="bold") |
| 246 | +ax3.legend() |
| 247 | +ax3.grid(True, alpha=0.3) |
207 | 248 | fig3.show() |
208 | 249 |
|
209 | 250 | # %% |
|
230 | 271 | # Plot transition matrices |
231 | 272 | matrices = [("Binary Symmetric Channel", bsc_matrix, p_bsc), ("Binary Erasure Channel", bec_matrix, p_bec), ("Z-Channel", z_matrix, p_z)] |
232 | 273 |
|
233 | | -fig4 = plot_transition_matrices(matrices, "Binary Channel Transition Matrices") |
| 274 | +fig4, axes = plt.subplots(1, 3, figsize=(15, 5), constrained_layout=True) |
| 275 | +fig4.suptitle("Binary Channel Transition Matrices", fontsize=16, fontweight="bold") |
| 276 | + |
| 277 | +for i, (name, matrix, p) in enumerate(matrices): |
| 278 | + ax = axes[i] |
| 279 | + im = ax.imshow(matrix, cmap=PlottingUtils.MATRIX_CMAP, interpolation="nearest", aspect="auto") |
| 280 | + ax.set_title(f"{name}\n(p={p:.2f})") |
| 281 | + ax.set_xlabel("Output") |
| 282 | + ax.set_ylabel("Input") |
| 283 | + |
| 284 | + # Add text annotations |
| 285 | + for row in range(matrix.shape[0]): |
| 286 | + for col in range(matrix.shape[1]): |
| 287 | + color = "white" if matrix[row, col] > 0.5 else "black" |
| 288 | + ax.text(col, row, f"{matrix[row, col]:.2f}", ha="center", va="center", color=color, fontsize=12, fontweight="bold") |
| 289 | + |
| 290 | + plt.colorbar(im, ax=ax, shrink=0.8) |
| 291 | + |
234 | 292 | fig4.show() |
235 | 293 |
|
236 | 294 | # %% |
@@ -278,7 +336,7 @@ def calculate_z_capacity(p): |
278 | 336 | # Plot capacity analysis |
279 | 337 | capacities = {"BSC": bsc_capacities, "BEC": bec_capacities, "Z-Channel": z_capacities} |
280 | 338 |
|
281 | | -fig5 = plot_channel_capacity_analysis(p_range, capacities, "Binary Channel Capacity Analysis") |
| 339 | +fig5 = PlottingUtils.plot_capacity_analysis(p_range, capacities, "Binary Channel Capacity Analysis") |
282 | 340 | fig5.show() |
283 | 341 |
|
284 | 342 | # %% |
|
0 commit comments