@@ -28,7 +28,10 @@ def __init__(self):
2828 self .fc = nn .Linear (10 , 20 )
2929
3030 def forward (self , x , * args , ** kwargs ):
31- return self .fc (x )
31+ continuous_output = self .fc (x )
32+ # Convert to binary output as some modulators (e.g., PSKModulator) expect binary inputs
33+ binary_output = (continuous_output > 0 ).float ()
34+ return binary_output
3235
3336
3437class SimpleDecoder (BaseModel ):
@@ -206,10 +209,10 @@ def test_basic_forward(self, basic_channel_code_model):
206209 # Process the input through the model
207210 output = model (input_data )
208211
209- # Check the output type and value (should be identical with identity components after binary conversion )
212+ # Check the output type and value (should be identical to input with identity components)
210213 assert isinstance (output , torch .Tensor )
211- binary_input_data = ( input_data > 0 ). float ()
212- assert torch .allclose (output , binary_input_data )
214+ # With identity components, the output should be the same as the input
215+ assert torch .allclose (output , input_data )
213216
214217 def test_forward_perfect_channel (self , simple_channel_code_model ):
215218 """Test forward pass with custom components and a perfect channel."""
0 commit comments