Skip to content

Commit 2234b82

Browse files
committed
Refactor Reed-Muller decoder for improved tensor handling and error checking; update tests to reflect parameter name changes in ReedMullerCodeEncoder.
1 parent e8e968b commit 2234b82

File tree

3 files changed

+56
-30
lines changed

3 files changed

+56
-30
lines changed

kaira/models/fec/decoders/reed_muller_decoder.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -241,17 +241,19 @@ def decode_block(r_block):
241241
errors = torch.zeros_like(r_block) if return_errors else None
242242

243243
for i in range(batch_size):
244-
# Get the current received word
245-
r = r_block[i]
244+
# Get the current received word - ensure it's a 1D tensor
245+
if r_block.dim() == 3: # Handle the case when r_block has shape [batch, 1, code_length]
246+
r = r_block[i, 0, :]
247+
else: # Handle the case when r_block has shape [batch, code_length]
248+
r = r_block[i, :]
246249

250+
"""
247251
# Convert to binary for hard decoding or compute hard decisions for soft decoding
248252
if self.input_type == "hard":
249253
bx = r.clone()
250254
else: # self.input_type == "soft"
251255
bx = (r < 0).to(torch.int)
252-
253-
# Original received bits (for error calculation)
254-
# original_bits = bx.clone() if return_errors else None
256+
"""
255257

256258
# Decode using Reed algorithm
257259
u_hat = torch.zeros(self.code_dimension, dtype=torch.int, device=received.device)
@@ -266,10 +268,20 @@ def decode_block(r_block):
266268
# Calculate checksums for each group in the partition
267269
checksums = []
268270
for group in partition:
271+
# Ensure the group indices are valid
272+
valid_indices = group[group < r.shape[0]]
273+
if len(valid_indices) == 0:
274+
continue
275+
269276
# Take relevant positions and compute parity
270-
group_bits = bx[group]
277+
# Use indexing to select elements from the 1D tensor
278+
group_bits = r[valid_indices].to(torch.int)
271279
checksum = torch.sum(group_bits) % 2
272-
checksums.append(checksum)
280+
checksums.append(checksum.item()) # Use .item() to convert tensor to scalar
281+
282+
# Skip if no valid checksums
283+
if not checksums:
284+
continue
273285

274286
# Convert to tensor
275287
checksums = torch.tensor(checksums, device=received.device)
@@ -284,17 +296,26 @@ def decode_block(r_block):
284296
min_reliabilities = []
285297

286298
for group in partition:
299+
# Ensure the group indices are valid
300+
valid_indices = group[group < r.shape[0]]
301+
if len(valid_indices) == 0:
302+
continue
303+
287304
# Take relevant positions
288-
group_bits = bx[group]
289-
group_reliabilities = torch.abs(r[group])
305+
group_bits = (r[valid_indices] < 0).to(torch.int)
306+
group_reliabilities = torch.abs(r[valid_indices])
290307

291308
# Compute parity of hard decisions
292309
checksum = torch.sum(group_bits) % 2
293-
checksums.append(checksum)
310+
checksums.append(checksum.item()) # Use .item() to convert tensor to scalar
294311

295312
# Find minimum reliability in this group
296313
min_reliability = torch.min(group_reliabilities)
297-
min_reliabilities.append(min_reliability)
314+
min_reliabilities.append(min_reliability.item()) # Use .item() to convert tensor to scalar
315+
316+
# Skip if no valid checksums
317+
if not checksums:
318+
continue
298319

299320
# Convert to tensors
300321
checksums = torch.tensor(checksums, device=received.device)
@@ -306,18 +327,14 @@ def decode_block(r_block):
306327
# Make decision
307328
u_hat[j] = (decision_var < 0).to(torch.int)
308329

309-
# Cancel the effect of this bit from the received word
310-
# In a complete implementation, this would use the generator matrix
311-
# bx ^= u_hat[j] * self.encoder.generator_matrix[j]
312-
313330
# Store the decoded message
314331
decoded[i] = u_hat
315332

316333
# Compute error pattern if needed
317334
if return_errors:
318335
# Re-encode the message to get the correct codeword
319-
correct_codeword = self.encoder(u_hat.unsqueeze(0)).squeeze(0)
320-
errors[i] = (r.to(torch.int) != correct_codeword).to(torch.int)
336+
correct_codeword = self.encoder(u_hat.float().unsqueeze(0)).squeeze(0)
337+
errors[i] = (r.to(torch.int) != correct_codeword.to(torch.int)).to(torch.int)
321338

322339
return (decoded, errors) if return_errors else decoded
323340

kaira/models/fec/utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def apply_blockwise(x: torch.Tensor, block_size: int, fn: Callable) -> torch.Ten
131131
a transformed tensor preserving the batch dimensions
132132
133133
Returns:
134-
Tensor with transformed blocks
134+
Tensor with transformed blocks or tuple of tensors if fn returns a tuple
135135
136136
Raises:
137137
AssertionError: If the last dimension is not divisible by block_size
@@ -152,5 +152,14 @@ def apply_blockwise(x: torch.Tensor, block_size: int, fn: Callable) -> torch.Ten
152152
# Apply function along the last dimension (block)
153153
result = fn(x_reshaped)
154154

155-
# Flatten the result back to original structure
156-
return result.view(*leading_dims, -1)
155+
# Check if the result is a tuple (like when return_errors=True)
156+
if isinstance(result, tuple):
157+
# Process each part of the tuple independently
158+
processed_results = []
159+
for res_part in result:
160+
# Flatten each part back to original structure
161+
processed_results.append(res_part.view(*leading_dims, -1))
162+
return tuple(processed_results)
163+
else:
164+
# Flatten the result back to original structure
165+
return result.view(*leading_dims, -1)

tests/models/fec/decoders/test_reed_muller_decoder.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TestReedMullerDecoder:
1313
def test_initialization(self):
1414
"""Test initialization with valid parameters."""
1515
# Create a Reed-Muller encoder for RM(1,3)
16-
encoder = ReedMullerCodeEncoder(r=1, m=3)
16+
encoder = ReedMullerCodeEncoder(order=1, length_param=3)
1717

1818
# Initialize the decoder with default parameters (hard-decision)
1919
decoder = ReedMullerDecoder(encoder=encoder)
@@ -35,7 +35,7 @@ def test_initialization(self):
3535
def test_generate_reed_partitions(self):
3636
"""Test generation of Reed partitions."""
3737
# Create a Reed-Muller encoder for RM(1,3)
38-
encoder = ReedMullerCodeEncoder(r=1, m=3)
38+
encoder = ReedMullerCodeEncoder(order=1, length_param=3)
3939
decoder = ReedMullerDecoder(encoder=encoder)
4040

4141
# Generate Reed partitions
@@ -51,7 +51,7 @@ def test_generate_reed_partitions(self):
5151
def test_decoding_no_errors_hard_decision(self):
5252
"""Test hard-decision decoding with no errors."""
5353
# Create a Reed-Muller encoder for RM(1,3)
54-
encoder = ReedMullerCodeEncoder(r=1, m=3)
54+
encoder = ReedMullerCodeEncoder(order=1, length_param=3)
5555
decoder = ReedMullerDecoder(encoder=encoder, input_type="hard")
5656

5757
# Create a message and encode it
@@ -71,7 +71,7 @@ def test_decoding_no_errors_hard_decision(self):
7171
def test_decoding_with_errors_hard_decision(self):
7272
"""Test hard-decision decoding with errors."""
7373
# Create a Reed-Muller encoder for RM(1,3)
74-
encoder = ReedMullerCodeEncoder(r=1, m=3)
74+
encoder = ReedMullerCodeEncoder(order=1, length_param=3)
7575
decoder = ReedMullerDecoder(encoder=encoder, input_type="hard")
7676

7777
# Create a message and encode it
@@ -91,7 +91,7 @@ def test_decoding_with_errors_hard_decision(self):
9191
def test_decoding_soft_decision(self):
9292
"""Test soft-decision decoding."""
9393
# Create a Reed-Muller encoder for RM(1,3)
94-
encoder = ReedMullerCodeEncoder(r=1, m=3)
94+
encoder = ReedMullerCodeEncoder(order=1, length_param=3)
9595
decoder = ReedMullerDecoder(encoder=encoder, input_type="soft")
9696

9797
# Create a message and encode it
@@ -117,7 +117,7 @@ def test_decoding_soft_decision(self):
117117
def test_decoding_with_return_errors(self):
118118
"""Test decoding with return_errors=True."""
119119
# Create a Reed-Muller encoder for RM(1,3)
120-
encoder = ReedMullerCodeEncoder(r=1, m=3)
120+
encoder = ReedMullerCodeEncoder(order=1, length_param=3)
121121
decoder = ReedMullerDecoder(encoder=encoder)
122122

123123
# Create a message and encode it
@@ -139,7 +139,7 @@ def test_decoding_with_return_errors(self):
139139
def test_decoding_with_batch_dimension(self):
140140
"""Test decoding with batch dimension."""
141141
# Create a Reed-Muller encoder for RM(1,3)
142-
encoder = ReedMullerCodeEncoder(r=1, m=3)
142+
encoder = ReedMullerCodeEncoder(order=1, length_param=3)
143143
decoder = ReedMullerDecoder(encoder=encoder)
144144

145145
# Create messages and encode them
@@ -166,7 +166,7 @@ def test_decoding_with_batch_dimension(self):
166166
def test_invalid_input_dimensions(self):
167167
"""Test decoding with invalid input dimensions."""
168168
# Create a Reed-Muller encoder for RM(1,3)
169-
encoder = ReedMullerCodeEncoder(r=1, m=3)
169+
encoder = ReedMullerCodeEncoder(order=1, length_param=3)
170170
decoder = ReedMullerDecoder(encoder=encoder)
171171

172172
# Create a received word with invalid length
@@ -180,15 +180,15 @@ def test_invalid_input_dimensions(self):
180180
def test_multiple_rm_parameters(self):
181181
"""Test with different Reed-Muller code parameters."""
182182
# Test with RM(0,3) - repetition code
183-
encoder_rm03 = ReedMullerCodeEncoder(r=0, m=3)
183+
encoder_rm03 = ReedMullerCodeEncoder(order=0, length_param=3)
184184
decoder_rm03 = ReedMullerDecoder(encoder=encoder_rm03)
185185

186186
# For RM(0,3), dimension = 1, length = 8
187187
assert encoder_rm03.code_dimension == 1
188188
assert encoder_rm03.code_length == 8
189189

190190
# Test with RM(1,4) - first-order code
191-
encoder_rm14 = ReedMullerCodeEncoder(r=1, m=4)
191+
encoder_rm14 = ReedMullerCodeEncoder(order=1, length_param=4)
192192
decoder_rm14 = ReedMullerDecoder(encoder=encoder_rm14)
193193

194194
# For RM(1,4), dimension = 5, length = 16

0 commit comments

Comments
 (0)