Skip to content

Commit 4a8fba1

Browse files
committed
Adding coverage on attention player
1 parent fd94f58 commit 4a8fba1

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

axelrod/tests/strategies/test_attention.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,27 @@ def test_compute_features(self):
4242
self.assertEqual(features[3].item(), GameState.DefectCooperate)
4343
self.assertEqual(features[4].item(), GameState.CooperateDefect)
4444

45+
def test_compute_features_right_pad(self):
46+
"""Test that features are computed correctly."""
47+
player = axl.MockPlayer(actions=[C, D, C, D])
48+
opponent = axl.MockPlayer(actions=[D, C, C, D])
49+
# Play the actions to populate history
50+
match = axl.Match((player, opponent), turns=4)
51+
match.play()
52+
53+
features = compute_features(player, opponent, True)
54+
55+
# Check the shape and type
56+
self.assertIsInstance(features, torch.Tensor)
57+
self.assertEqual(features.shape, (MEMORY_LENGTH + 1,))
58+
59+
# Check specific values (CLS token and game states)
60+
self.assertEqual(features[0].item(), 0) # CLS token
61+
self.assertEqual(features[1].item(), GameState.DefectDefect)
62+
self.assertEqual(features[2].item(), GameState.CooperateCooperate)
63+
self.assertEqual(features[3].item(), GameState.DefectCooperate)
64+
self.assertEqual(features[4].item(), GameState.CooperateDefect)
65+
4566
def test_actions_to_game_state(self):
4667
"""Test the mapping from actions to game states."""
4768
self.assertEqual(
@@ -52,6 +73,13 @@ def test_actions_to_game_state(self):
5273
self.assertEqual(actions_to_game_state(D, D), GameState.DefectDefect)
5374

5475

76+
class TestAttention(unittest.TestCase):
77+
def test_initilization(self):
78+
"""Test that the model is initialized correctly."""
79+
player = axl.Attention()
80+
self.assertIsInstance(player.model, PlayerModel)
81+
82+
5583
class TestEvolvedAttention(TestPlayer):
5684
name = "EvolvedAttention"
5785
player = axl.EvolvedAttention
@@ -66,7 +94,7 @@ class TestEvolvedAttention(TestPlayer):
6694
}
6795

6896
def test_model_initialization(self):
69-
"""Test that the model is initialized with pretrained weights."""
97+
"""Test that the model is initialized correctly."""
7098
player = self.player()
7199
self.assertIsInstance(player.model, PlayerModel)
72100

0 commit comments

Comments
 (0)