Skip to content

Commit 98b0d2d

Browse files
committed
Enhance data handling and visualization in channel comparison and MAC integration examples; improve comments and remove unnecessary print statements in FEC tutorials.
1 parent 72646de commit 98b0d2d

File tree

8 files changed

+122
-101
lines changed

8 files changed

+122
-101
lines changed

examples/channels/plot_channel_comparison.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,20 @@ def apply_channels(data, channels, channel_names):
136136
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
137137
axes = axes.flatten()
138138

139-
# Add an "Original" plot
140-
axes[0].scatter(np.arange(100), continuous_data[:100], color=colors[0], alpha=0.7, s=40)
139+
140+
# Add an "Original" plot with safe conversion
141+
def safe_data_conversion(data, slice_obj=None):
    """Safely convert data to a real-valued numpy array for plotting.

    Parameters
    ----------
    data : torch.Tensor or array-like
        The data to convert; may be complex-valued.
    slice_obj : slice, optional
        If given, applied to ``data`` before conversion.

    Returns
    -------
    numpy.ndarray
        Real-valued numpy array suitable for matplotlib plotting.
    """
    if slice_obj is not None:
        data = data[slice_obj]
    if isinstance(data, torch.Tensor):
        # detach().cpu() guards against tensors that require grad or live on
        # an accelerator, where a bare .numpy() would raise.
        data_np = data.detach().cpu().numpy()
    else:
        # Normalize lists/tuples as well, not only ndarrays.
        data_np = np.asarray(data)
    if np.iscomplexobj(data_np):
        # Matplotlib scatter expects real values; keep only the real part.
        data_np = data_np.real
    return data_np
149+
150+
151+
original_data_safe = safe_data_conversion(continuous_data, slice(100))
152+
axes[0].scatter(np.arange(100), original_data_safe, color=colors[0], alpha=0.7, s=40)
141153
axes[0].set_title("Original Signal", fontsize=14)
142154
axes[0].set_xlabel("Sample Index", fontsize=12)
143155
axes[0].set_ylabel("Signal Value", fontsize=12)
@@ -437,14 +449,45 @@ def complex_to_real(x):
437449
awgn_15db = AWGNChannel(snr_db=15)
438450
fading_10db = FlatFadingChannel(fading_type="rayleigh", coherence_time=1, snr_db=10)
439451

452+
453+
# Helper function to safely convert channel output to real-valued numpy array
454+
def safe_numpy_conversion(tensor_data):
    """Convert tensor to numpy array, handling complex values properly.

    This function ensures that any complex-valued outputs from channels
    are properly converted to real values before being used in matplotlib
    plots, preventing ComplexWarning messages.

    Parameters
    ----------
    tensor_data : torch.Tensor
        The tensor data to convert, which may contain complex values.

    Returns
    -------
    numpy.ndarray
        Real-valued numpy array suitable for plotting.
    """
    # detach().cpu() guards against tensors that require grad or live on an
    # accelerator, where a bare .numpy() would raise.
    numpy_data = tensor_data.detach().cpu().numpy()
    if np.iscomplexobj(numpy_data):
        # For constellation diagrams the complex data should already be in
        # I/Q format with real/imag parts as separate columns, so taking the
        # real part here is a fallback.
        # Only report the conversion when the discarded imaginary part is
        # actually significant (the original printed unconditionally for any
        # complex dtype, contradicting its own comment). The magnitude must
        # be checked BEFORE dropping the imaginary part.
        if np.abs(numpy_data.imag).max() > 1e-9:
            print(f"Info: Complex tensor converted to real part for plotting. Shape: {numpy_data.shape}")
        numpy_data = numpy_data.real
    return numpy_data
481+
482+
440483
# Process data through channels
441-
qpsk_awgn_5db = awgn_5db(qpsk_data).numpy()
442-
qpsk_awgn_15db = awgn_15db(qpsk_data).numpy()
443-
qpsk_fading = fading_10db(qpsk_data).numpy()
484+
qpsk_awgn_5db = safe_numpy_conversion(awgn_5db(qpsk_data))
485+
qpsk_awgn_15db = safe_numpy_conversion(awgn_15db(qpsk_data))
486+
qpsk_fading = safe_numpy_conversion(fading_10db(qpsk_data))
444487

445-
qam_awgn_5db = awgn_5db(qam_data).numpy()
446-
qam_awgn_15db = awgn_15db(qam_data).numpy()
447-
qam_fading = fading_10db(qam_data).numpy()
488+
qam_awgn_5db = safe_numpy_conversion(awgn_5db(qam_data))
489+
qam_awgn_15db = safe_numpy_conversion(awgn_15db(qam_data))
490+
qam_fading = safe_numpy_conversion(fading_10db(qam_data))
448491

449492
# Create constellation plots
450493
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

examples/constraints/plot_basic_constraints.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,6 @@
336336
avg_power_signals = {name: torch.tensor(data).reshape(1, -1) for name, data in avg_power_results.items()}
337337

338338
# Compare constraints side by side
339-
print("\n=== Signal Properties Comparison ===")
340339
fig, axes = plt.subplots(2, 2, figsize=(15, 10), constrained_layout=True)
341340
fig.suptitle("Signal Properties Analysis - All Constraints", fontsize=16, fontweight="bold")
342341

examples/models/plot_afmodule.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -314,19 +314,22 @@ def forward(self, x, snr):
314314

315315
# %%
316316
# Conclusion
317-
# ---------------
317+
# ----------
318+
#
318319
# In this example, we explored the Attention-Feature Module (AFModule), a component
319320
# designed to help neural networks adapt to varying channel conditions in wireless
320321
# communication systems.
321322
#
322-
# Key points:
323-
# - AFModule recalibrates feature maps based on channel state information
324-
# - It can work with different input tensor dimensions (2D, 3D, 4D)
325-
# - It helps maintain performance across different channel conditions (like varying SNRs)
326-
# - The module can adapt to different feature sizes dynamically
323+
# **Key Points:**
324+
#
325+
# * AFModule recalibrates feature maps based on channel state information
326+
# * It can work with different input tensor dimensions (2D, 3D, 4D)
327+
# * It helps maintain performance across different channel conditions (like varying SNRs)
328+
# * The module can adapt to different feature sizes dynamically
327329
#
328330
# The AFModule is particularly useful in deep learning-based communication systems
329331
# that need to operate reliably in varying channel conditions.
330332
#
331-
# References:
332-
# - :cite:`xu2021wireless`
333+
# **References:**
334+
#
335+
# * :cite:`xu2021wireless`

examples/models/plot_channel_aware_base_model.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ def forward(self, x: torch.Tensor, csi: torch.Tensor, *args, **kwargs) -> torch.
8181
csi_dim = 2
8282
model = SimpleChannelAwareEncoder(input_dim, output_dim, csi_dim)
8383

84-
print("Created SimpleChannelAwareEncoder:")
85-
print(f" Input dimension: {input_dim}")
86-
print(f" Output dimension: {output_dim}")
87-
print(f" CSI dimension: {csi_dim}")
84+
# Created SimpleChannelAwareEncoder:
85+
# Input dimension: {input_dim}
86+
# Output dimension: {output_dim}
87+
# CSI dimension: {csi_dim}
8888

8989
# %%
9090
# Demonstrating CSI Validation and Normalization
@@ -96,81 +96,79 @@ def forward(self, x: torch.Tensor, csi: torch.Tensor, *args, **kwargs) -> torch.
9696
x = torch.randn(batch_size, input_dim)
9797

9898
# Different CSI examples
99-
print("\n=== CSI Validation Examples ===")
100-
10199
# Valid CSI
102100
csi_valid = torch.tensor([[10.0, 0.5], [15.0, 0.8], [5.0, 0.3], [20.0, 0.9], [12.0, 0.6], [8.0, 0.4], [18.0, 0.7], [14.0, 0.65]])
103-
print(f"Valid CSI shape: {csi_valid.shape}")
101+
# Valid CSI shape: {csi_valid.shape}
104102

105103
try:
106104
validated_csi = model.validate_csi(csi_valid)
107105
print("✓ CSI validation passed")
106+
except ValueError as e:
107+
# ✗ CSI validation failed due to value error
108+
print("✗ CSI validation failed: " + str(e))
109+
validated_csi = None
108110
except Exception as e:
109-
print(f"✗ CSI validation failed: {e}")
111+
# ✗ CSI validation failed due to unexpected error
112+
print("✗ Unexpected error during CSI validation: " + str(e))
113+
validated_csi = None
110114

111115
# Test normalization
112-
print("\n=== CSI Normalization Examples ===")
113-
114116
# MinMax normalization
115117
normalized_minmax = model.normalize_csi(csi_valid, method="minmax", target_range=(0, 1))
116-
print(f"Original CSI range: [{csi_valid.min():.2f}, {csi_valid.max():.2f}]")
117-
print(f"MinMax normalized range: [{normalized_minmax.min():.2f}, {normalized_minmax.max():.2f}]")
118+
# Original CSI range: [{csi_valid.min():.2f}, {csi_valid.max():.2f}]
119+
# MinMax normalized range: [{normalized_minmax.min():.2f}, {normalized_minmax.max():.2f}]
118120

119121
# Z-score normalization
120122
normalized_zscore = model.normalize_csi(csi_valid, method="zscore")
121-
print(f"Z-score normalized mean: {normalized_zscore.mean():.4f}, std: {normalized_zscore.std():.4f}")
123+
# Z-score normalized mean: {normalized_zscore.mean():.4f}, std: {normalized_zscore.std():.4f}
122124

123125
# %%
124126
# Working with the AFModule
125127
# -------------------------
126128
# The AFModule has been updated to use ChannelAwareBaseModel. Let's demonstrate its usage.
127129

128-
print("\n=== AFModule (Channel-Aware) Example ===")
129-
130130
# Create AFModule
131131
N = 64 # Number of feature channels
132132
csi_length = 1 # CSI vector length
133133
af_module = AFModule(N=N, csi_length=csi_length)
134134

135135
# Create feature maps (4D tensor for image-like data)
136136
feature_maps = torch.randn(batch_size, N, 8, 8)
137-
print(f"Feature maps shape: {feature_maps.shape}")
137+
# Feature maps shape: {feature_maps.shape}
138138

139139
# Create CSI for AFModule (SNR values in dB)
140140
snr_values = torch.tensor([10.0, 15.0, 5.0, 20.0, 12.0, 8.0, 18.0, 14.0])
141141
csi_af = snr_values.unsqueeze(1) # Shape: [batch_size, 1]
142-
print(f"CSI shape for AFModule: {csi_af.shape}")
142+
# CSI shape for AFModule: {csi_af.shape}
143143

144144
# Apply AFModule
145145
with torch.no_grad():
146146
modulated_features = af_module(feature_maps, csi=csi_af)
147147

148-
print(f"Modulated features shape: {modulated_features.shape}")
149-
print(f"Feature modulation factor range: [{(modulated_features/feature_maps).min():.3f}, {(modulated_features/feature_maps).max():.3f}]")
148+
# Modulated features shape: {modulated_features.shape}
149+
# Feature modulation factor range: [{(modulated_features/feature_maps).min():.3f}, {(modulated_features/feature_maps).max():.3f}]
150150

151151
# %%
152152
# CSI Feature Extraction
153153
# ----------------------
154154
# The base class provides methods to extract useful features from CSI.
155155

156-
print("\n=== CSI Feature Extraction ===")
157-
158156
csi_features = model.extract_csi_features(csi_valid)
159-
print("Extracted CSI features:")
157+
# Extracted CSI features:
160158
for feature_name, feature_value in csi_features.items():
161159
if isinstance(feature_value, torch.Tensor):
162160
if feature_value.numel() == 1:
163-
print(f" {feature_name}: {feature_value.item():.4f}")
161+
# Single-value feature: {feature_name}: {feature_value.item():.4f}
162+
pass
164163
else:
165-
print(f" {feature_name}: {feature_value.tolist()}")
164+
# Multi-value feature: {feature_name}: {feature_value.tolist()}
165+
pass
166166

167167
# %%
168168
# Visualization of CSI Effects
169169
# ----------------------------
170170
# Let's visualize how different CSI values affect model outputs.
171171

172-
print("\n=== Visualizing CSI Effects ===")
173-
174172
# Generate a range of CSI values
175173
snr_range = torch.linspace(-5, 25, 31) # SNR from -5 to 25 dB
176174
quality_factor = torch.ones_like(snr_range) * 0.5 # Fixed quality factor
@@ -204,7 +202,7 @@ def forward(self, x: torch.Tensor, csi: torch.Tensor, *args, **kwargs) -> torch.
204202
# Plot 2: First few output dimensions vs SNR
205203
actual_output_dim = outputs.shape[1] if outputs.ndim > 1 else 1
206204
for i in range(min(4, actual_output_dim)):
207-
axes[0, 1].plot(snr_range.numpy(), outputs[:, i], label=f"Dim {i+1}")
205+
axes[0, 1].plot(snr_range.numpy(), outputs[:, i], label="Dim " + str(i + 1))
208206
axes[0, 1].set_xlabel("SNR (dB)")
209207
axes[0, 1].set_ylabel("Output Value")
210208
axes[0, 1].set_title("Output Dimensions vs CSI (SNR)")
@@ -260,8 +258,6 @@ def forward(self, x: torch.Tensor, csi: torch.Tensor, *args, **kwargs) -> torch.
260258
# ---------------------------------
261259
# Demonstrate how to extract and use CSI from channel outputs.
262260

263-
print("\n=== Integration with Channels ===")
264-
265261
# Create channels that might provide CSI
266262
awgn_channel = AWGNChannel(snr_db=15.0)
267263
fading_channel = FlatFadingChannel(fading_type="rayleigh", coherence_time=50, snr_db=10.0)
@@ -273,24 +269,24 @@ def forward(self, x: torch.Tensor, csi: torch.Tensor, *args, **kwargs) -> torch.
273269
awgn_output = awgn_channel(test_signal)
274270
fading_output = fading_channel(test_signal)
275271

276-
print(f"AWGN channel output shape: {awgn_output.shape}")
277-
print(f"Fading channel output shape: {fading_output.shape}")
272+
# AWGN channel output shape: {awgn_output.shape}
273+
# Fading channel output shape: {fading_output.shape}
278274

279275
# Try to extract CSI (channels might not provide it directly)
280276
awgn_csi = model.extract_csi_from_channel_output(awgn_output)
281277
fading_csi = model.extract_csi_from_channel_output(fading_output)
282278

283-
print(f"Extracted CSI from AWGN: {awgn_csi}")
284-
print(f"Extracted CSI from Fading: {fading_csi}")
279+
# Extracted CSI from AWGN: {awgn_csi}
280+
# Extracted CSI from Fading: {fading_csi}
285281

286282
# Create CSI manually based on channel properties
287283
if hasattr(awgn_channel, "snr_db"):
288284
manual_awgn_csi = torch.full((batch_size, 1), awgn_channel.snr_db)
289-
print(f"Manual AWGN CSI: {manual_awgn_csi.mean().item():.1f} dB")
285+
print("Manual AWGN CSI: " + str(manual_awgn_csi.mean().item()) + " dB")
290286

291287
if hasattr(fading_channel, "snr_db"):
292288
manual_fading_csi = torch.full((batch_size, 1), fading_channel.snr_db)
293-
print(f"Manual Fading CSI: {manual_fading_csi.mean().item():.1f} dB")
289+
print("Manual Fading CSI: " + str(manual_fading_csi.mean().item()) + " dB")
294290

295291
# %%
296292
# Best Practices Summary

examples/models/plot_uplink_mac_integration.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def forward(self, user_messages, **kwargs):
8787
# %%
8888
# Setup Parameters
8989
# --------------------------------------------------------------------------
90-
print("=== UplinkMACChannel Integration Example ===\n")
90+
# === UplinkMACChannel Integration Example ===
9191

9292
# System parameters
9393
num_users = 3
@@ -100,19 +100,19 @@ def forward(self, user_messages, **kwargs):
100100
avg_noise_power = 0.1
101101
snr_db = 12.0
102102

103-
print("System Configuration:")
104-
print(f"- Number of users: {num_users}")
105-
print(f"- Message dimension: {message_dim}")
106-
print(f"- Code dimension: {code_dim}")
107-
print(f"- Batch size: {batch_size}")
108-
print(f"- SNR: {snr_db} dB")
109-
print(f"- Coherence time: {coherence_time}")
110-
print(f"- Average noise power: {avg_noise_power}\n")
103+
# System Configuration:
104+
# - Number of users: {num_users}
105+
# - Message dimension: {message_dim}
106+
# - Code dimension: {code_dim}
107+
# - Batch size: {batch_size}
108+
# - SNR: {snr_db} dB
109+
# - Coherence time: {coherence_time}
110+
# - Average noise power: {avg_noise_power}
111111

112112
# %%
113113
# Scenario 1: Shared Channel Configuration
114114
# --------------------------------------------------------------------------
115-
print("--- Scenario 1: Shared Channel Configuration ---")
115+
# --- Scenario 1: Shared Channel Configuration ---
116116

117117
# Create a shared Rayleigh fading channel for all users
118118
shared_channel = RayleighFadingChannel(coherence_time=coherence_time, avg_noise_power=avg_noise_power)
@@ -147,15 +147,15 @@ def forward(self, user_messages, **kwargs):
147147
for i in range(num_users):
148148
user_mse = mse_loss(reconstructed_messages[i], user_messages[i]).item()
149149
total_mse_shared += user_mse
150-
print(f"User {i+1} MSE (Shared Channel): {user_mse:.6f}")
150+
# User {i+1} MSE (Shared Channel): {user_mse:.6f}
151151

152152
avg_mse_shared = total_mse_shared / num_users
153-
print(f"Average MSE (Shared Channel): {avg_mse_shared:.6f}\n")
153+
# Average MSE (Shared Channel): {avg_mse_shared:.6f}
154154

155155
# %%
156156
# Scenario 2: Per-User Channel Configuration
157157
# --------------------------------------------------------------------------
158-
print("--- Scenario 2: Per-User Channel Configuration ---")
158+
# --- Scenario 2: Per-User Channel Configuration ---
159159

160160
# Create individual channels for each user with different characteristics
161161
per_user_channels = [
@@ -187,21 +187,21 @@ def forward(self, user_messages, **kwargs):
187187
for i in range(num_users):
188188
user_mse = mse_loss(reconstructed_messages_per_user[i], user_messages[i]).item()
189189
total_mse_per_user += user_mse
190-
print(f"User {i+1} MSE (Per-User Channel): {user_mse:.6f}")
190+
# User {i+1} MSE (Per-User Channel): {user_mse:.6f}
191191

192192
avg_mse_per_user = total_mse_per_user / num_users
193-
print(f"Average MSE (Per-User Channel): {avg_mse_per_user:.6f}\n")
193+
# Average MSE (Per-User Channel): {avg_mse_per_user:.6f}
194194

195195
# %%
196196
# Scenario 3: Combining Methods Comparison
197197
# --------------------------------------------------------------------------
198-
print("--- Scenario 3: Combining Methods Comparison ---")
198+
# --- Scenario 3: Combining Methods Comparison ---
199199

200200
combining_methods = ["sum", "weighted_sum"]
201201
combine_results = {}
202202

203203
for method in combining_methods:
204-
print(f"Testing combining method: {method}")
204+
# Testing combining method: {method}
205205

206206
# Create channels for this test
207207
test_channels = [FlatFadingChannel(fading_type="rayleigh", coherence_time=coherence_time, avg_noise_power=avg_noise_power) for _ in range(num_users)]
@@ -230,12 +230,12 @@ def forward(self, user_messages, **kwargs):
230230

231231
avg_mse_test = total_mse_test / num_users
232232
combine_results[method] = avg_mse_test
233-
print(f" Average MSE with {method}: {avg_mse_test:.6f}")
233+
# Average MSE with {method}: {avg_mse_test:.6f}
234234

235235
# %%
236236
# Scenario 4: Dynamic Parameter Updates
237237
# --------------------------------------------------------------------------
238-
print("\n--- Scenario 4: Dynamic Parameter Updates ---")
238+
# --- Scenario 4: Dynamic Parameter Updates ---
239239

240240
# Create UplinkMACChannel for dynamic updates
241241
dynamic_channels = [FlatFadingChannel(fading_type="rayleigh", coherence_time=coherence_time, avg_noise_power=avg_noise_power) for _ in range(num_users)]
@@ -254,7 +254,7 @@ def forward(self, user_messages, **kwargs):
254254
mse_results_dynamic = []
255255

256256
for snr in snr_values:
257-
print(f"Testing at SNR = {snr} dB...")
257+
# Testing at SNR = {snr} dB...
258258

259259
# Update individual channel parameters dynamically
260260
new_noise_power = 0.1 * (10 ** (-snr / 10)) # Adjust noise based on SNR
@@ -282,7 +282,7 @@ def forward(self, user_messages, **kwargs):
282282

283283
avg_mse_dynamic = total_mse_dynamic / num_users
284284
mse_results_dynamic.append(avg_mse_dynamic)
285-
print(f" Average MSE: {avg_mse_dynamic:.6f}")
285+
# Average MSE: {avg_mse_dynamic:.6f}
286286

287287
# %%
288288
# Plotting Results

0 commit comments

Comments
 (0)