Skip to content

Commit 88fbd97

Browse files
committed
Refactor SimpleEncoder to convert continuous output to binary and update test assertions for identity components
1 parent e61191e commit 88fbd97

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/models/test_models_channel_code.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ def __init__(self):
2828
self.fc = nn.Linear(10, 20)
2929

3030
def forward(self, x, *args, **kwargs):
31-
return self.fc(x)
31+
continuous_output = self.fc(x)
32+
# Convert to binary output as some modulators (e.g., PSKModulator) expect binary inputs
33+
binary_output = (continuous_output > 0).float()
34+
return binary_output
3235

3336

3437
class SimpleDecoder(BaseModel):
@@ -206,10 +209,10 @@ def test_basic_forward(self, basic_channel_code_model):
206209
# Process the input through the model
207210
output = model(input_data)
208211

209-
# Check the output type and value (should be identical with identity components after binary conversion)
212+
# Check the output type and value (should be identical to input with identity components)
210213
assert isinstance(output, torch.Tensor)
211-
binary_input_data = (input_data > 0).float()
212-
assert torch.allclose(output, binary_input_data)
214+
# With identity components, the output should be the same as the input
215+
assert torch.allclose(output, input_data)
213216

214217
def test_forward_perfect_channel(self, simple_channel_code_model):
215218
"""Test forward pass with custom components and a perfect channel."""

0 commit comments

Comments
 (0)