@@ -81,10 +81,10 @@ def forward(self, x: torch.Tensor, csi: torch.Tensor, *args, **kwargs) -> torch.
8181csi_dim = 2
8282model = SimpleChannelAwareEncoder (input_dim , output_dim , csi_dim )
8383
84- print ( " Created SimpleChannelAwareEncoder:" )
85- print ( f" Input dimension: { input_dim } " )
86- print ( f" Output dimension: { output_dim } " )
87- print ( f" CSI dimension: { csi_dim } " )
84+ # Created SimpleChannelAwareEncoder:
85+ # Input dimension: {input_dim}
86+ # Output dimension: {output_dim}
87+ # CSI dimension: {csi_dim}
8888
8989# %%
9090# Demonstrating CSI Validation and Normalization
@@ -96,81 +96,79 @@ def forward(self, x: torch.Tensor, csi: torch.Tensor, *args, **kwargs) -> torch.
9696x = torch .randn (batch_size , input_dim )
9797
9898# Different CSI examples
99- print ("\n === CSI Validation Examples ===" )
100-
10199# Valid CSI
102100csi_valid = torch .tensor ([[10.0 , 0.5 ], [15.0 , 0.8 ], [5.0 , 0.3 ], [20.0 , 0.9 ], [12.0 , 0.6 ], [8.0 , 0.4 ], [18.0 , 0.7 ], [14.0 , 0.65 ]])
103- print ( f" Valid CSI shape: { csi_valid .shape } " )
101+ # Valid CSI shape: {csi_valid.shape}
104102
105103try :
106104 validated_csi = model .validate_csi (csi_valid )
107105 print ("✓ CSI validation passed" )
106+ except ValueError as e :
107+ # ✗ CSI validation failed due to value error
108+ print ("✗ CSI validation failed: " + str (e ))
109+ validated_csi = None
108110except Exception as e :
109- print (f"✗ CSI validation failed: { e } " )
111+ # ✗ CSI validation failed due to unexpected error
112+ print ("✗ Unexpected error during CSI validation: " + str (e ))
113+ validated_csi = None
110114
111115# Test normalization
112- print ("\n === CSI Normalization Examples ===" )
113-
114116# MinMax normalization
115117normalized_minmax = model .normalize_csi (csi_valid , method = "minmax" , target_range = (0 , 1 ))
116- print ( f" Original CSI range: [{ csi_valid .min ():.2f} , { csi_valid .max ():.2f} ]" )
117- print ( f" MinMax normalized range: [{ normalized_minmax .min ():.2f} , { normalized_minmax .max ():.2f} ]" )
118+ # Original CSI range: [{csi_valid.min():.2f}, {csi_valid.max():.2f}]
119+ # MinMax normalized range: [{normalized_minmax.min():.2f}, {normalized_minmax.max():.2f}]
118120
119121# Z-score normalization
120122normalized_zscore = model .normalize_csi (csi_valid , method = "zscore" )
121- print ( f" Z-score normalized mean: { normalized_zscore .mean ():.4f} , std: { normalized_zscore .std ():.4f} " )
123+ # Z-score normalized mean: {normalized_zscore.mean():.4f}, std: {normalized_zscore.std():.4f}
122124
123125# %%
124126# Working with the AFModule
125127# -------------------------
126128# The AFModule has been updated to use ChannelAwareBaseModel. Let's demonstrate its usage.
127129
128- print ("\n === AFModule (Channel-Aware) Example ===" )
129-
130130# Create AFModule
131131N = 64 # Number of feature channels
132132csi_length = 1 # CSI vector length
133133af_module = AFModule (N = N , csi_length = csi_length )
134134
135135# Create feature maps (4D tensor for image-like data)
136136feature_maps = torch .randn (batch_size , N , 8 , 8 )
137- print ( f" Feature maps shape: { feature_maps .shape } " )
137+ # Feature maps shape: {feature_maps.shape}
138138
139139# Create CSI for AFModule (SNR values in dB)
140140snr_values = torch .tensor ([10.0 , 15.0 , 5.0 , 20.0 , 12.0 , 8.0 , 18.0 , 14.0 ])
141141csi_af = snr_values .unsqueeze (1 ) # Shape: [batch_size, 1]
142- print ( f" CSI shape for AFModule: { csi_af .shape } " )
142+ # CSI shape for AFModule: {csi_af.shape}
143143
144144# Apply AFModule
145145with torch .no_grad ():
146146 modulated_features = af_module (feature_maps , csi = csi_af )
147147
148- print ( f" Modulated features shape: { modulated_features .shape } " )
149- print ( f" Feature modulation factor range: [{ (modulated_features / feature_maps ).min ():.3f} , { (modulated_features / feature_maps ).max ():.3f} ]" )
148+ # Modulated features shape: {modulated_features.shape}
149+ # Feature modulation factor range: [{(modulated_features/feature_maps).min():.3f}, {(modulated_features/feature_maps).max():.3f}]
150150
151151# %%
152152# CSI Feature Extraction
153153# ----------------------
154154# The base class provides methods to extract useful features from CSI.
155155
156- print ("\n === CSI Feature Extraction ===" )
157-
158156csi_features = model .extract_csi_features (csi_valid )
159- print ( " Extracted CSI features:" )
157+ # Extracted CSI features:
160158for feature_name , feature_value in csi_features .items ():
161159 if isinstance (feature_value , torch .Tensor ):
162160 if feature_value .numel () == 1 :
163- print (f" { feature_name } : { feature_value .item ():.4f} " )
161+ # Single-value feature: {feature_name}: {feature_value.item():.4f}
162+ pass
164163 else :
165- print (f" { feature_name } : { feature_value .tolist ()} " )
164+ # Multi-value feature: {feature_name}: {feature_value.tolist()}
165+ pass
166166
167167# %%
168168# Visualization of CSI Effects
169169# ----------------------------
170170# Let's visualize how different CSI values affect model outputs.
171171
172- print ("\n === Visualizing CSI Effects ===" )
173-
174172# Generate a range of CSI values
175173snr_range = torch .linspace (- 5 , 25 , 31 ) # SNR from -5 to 25 dB
176174quality_factor = torch .ones_like (snr_range ) * 0.5 # Fixed quality factor
@@ -204,7 +202,7 @@ def forward(self, x: torch.Tensor, csi: torch.Tensor, *args, **kwargs) -> torch.
204202# Plot 2: First few output dimensions vs SNR
205203actual_output_dim = outputs .shape [1 ] if outputs .ndim > 1 else 1
206204for i in range (min (4 , actual_output_dim )):
207- axes [0 , 1 ].plot (snr_range .numpy (), outputs [:, i ], label = f "Dim { i + 1 } " )
205+ axes [0 , 1 ].plot (snr_range .numpy (), outputs [:, i ], label = "Dim " + str ( i + 1 ) )
208206axes [0 , 1 ].set_xlabel ("SNR (dB)" )
209207axes [0 , 1 ].set_ylabel ("Output Value" )
210208axes [0 , 1 ].set_title ("Output Dimensions vs CSI (SNR)" )
@@ -260,8 +258,6 @@ def forward(self, x: torch.Tensor, csi: torch.Tensor, *args, **kwargs) -> torch.
260258# ---------------------------------
261259# Demonstrate how to extract and use CSI from channel outputs.
262260
263- print ("\n === Integration with Channels ===" )
264-
265261# Create channels that might provide CSI
266262awgn_channel = AWGNChannel (snr_db = 15.0 )
267263fading_channel = FlatFadingChannel (fading_type = "rayleigh" , coherence_time = 50 , snr_db = 10.0 )
@@ -273,24 +269,24 @@ def forward(self, x: torch.Tensor, csi: torch.Tensor, *args, **kwargs) -> torch.
273269awgn_output = awgn_channel (test_signal )
274270fading_output = fading_channel (test_signal )
275271
276- print ( f" AWGN channel output shape: { awgn_output .shape } " )
277- print ( f" Fading channel output shape: { fading_output .shape } " )
272+ # AWGN channel output shape: {awgn_output.shape}
273+ # Fading channel output shape: {fading_output.shape}
278274
279275# Try to extract CSI (channels might not provide it directly)
280276awgn_csi = model .extract_csi_from_channel_output (awgn_output )
281277fading_csi = model .extract_csi_from_channel_output (fading_output )
282278
283- print ( f" Extracted CSI from AWGN: { awgn_csi } " )
284- print ( f" Extracted CSI from Fading: { fading_csi } " )
279+ # Extracted CSI from AWGN: {awgn_csi}
280+ # Extracted CSI from Fading: {fading_csi}
285281
286282# Create CSI manually based on channel properties
287283if hasattr (awgn_channel , "snr_db" ):
288284 manual_awgn_csi = torch .full ((batch_size , 1 ), awgn_channel .snr_db )
289- print (f "Manual AWGN CSI: { manual_awgn_csi .mean ().item ():.1f } dB" )
285+ print ("Manual AWGN CSI: " + str ( manual_awgn_csi .mean ().item ()) + " dB" )
290286
291287if hasattr (fading_channel , "snr_db" ):
292288 manual_fading_csi = torch .full ((batch_size , 1 ), fading_channel .snr_db )
293- print (f "Manual Fading CSI: { manual_fading_csi .mean ().item ():.1f } dB" )
289+ print ("Manual Fading CSI: " + str ( manual_fading_csi .mean ().item ()) + " dB" )
294290
295291# %%
296292# Best Practices Summary
0 commit comments