@@ -42,6 +42,27 @@ def test_compute_features(self):
4242 self .assertEqual (features [3 ].item (), GameState .DefectCooperate )
4343 self .assertEqual (features [4 ].item (), GameState .CooperateDefect )
4444
45+ def test_compute_features_right_pad (self ):
46+ """Test that features are computed correctly."""
47+ player = axl .MockPlayer (actions = [C , D , C , D ])
48+ opponent = axl .MockPlayer (actions = [D , C , C , D ])
49+ # Play the actions to populate history
50+ match = axl .Match ((player , opponent ), turns = 4 )
51+ match .play ()
52+
53+ features = compute_features (player , opponent , True )
54+
55+ # Check the shape and type
56+ self .assertIsInstance (features , torch .Tensor )
57+ self .assertEqual (features .shape , (MEMORY_LENGTH + 1 ,))
58+
59+ # Check specific values (CLS token and game states)
60+ self .assertEqual (features [0 ].item (), 0 ) # CLS token
61+ self .assertEqual (features [1 ].item (), GameState .DefectDefect )
62+ self .assertEqual (features [2 ].item (), GameState .CooperateCooperate )
63+ self .assertEqual (features [3 ].item (), GameState .DefectCooperate )
64+ self .assertEqual (features [4 ].item (), GameState .CooperateDefect )
65+
4566 def test_actions_to_game_state (self ):
4667 """Test the mapping from actions to game states."""
4768 self .assertEqual (
@@ -52,6 +73,13 @@ def test_actions_to_game_state(self):
5273 self .assertEqual (actions_to_game_state (D , D ), GameState .DefectDefect )
5374
5475
76+ class TestAttention (unittest .TestCase ):
77+ def test_initilization (self ):
78+ """Test that the model is initialized correctly."""
79+ player = axl .Attention ()
80+ self .assertIsInstance (player .model , PlayerModel )
81+
82+
5583class TestEvolvedAttention (TestPlayer ):
5684 name = "EvolvedAttention"
5785 player = axl .EvolvedAttention
@@ -66,7 +94,7 @@ class TestEvolvedAttention(TestPlayer):
6694 }
6795
6896 def test_model_initialization (self ):
69- """Test that the model is initialized with pretrained weights ."""
97+ """Test that the model is initialized correctly ."""
7098 player = self .player ()
7199 self .assertIsInstance (player .model , PlayerModel )
72100
0 commit comments