Skip to content

Commit f1f838d

Browse files
drvinceknightmarcharper
authored andcommitted
Add type hints.
This fixes the two errors: ``` axelrod/fingerprint.py:444: error: Function "numpy.array" is not valid as a type 22 axelrod/fingerprint.py:444: note: Perhaps you need "Callable[...]" or a callback protocol? 23 Found 1 error in 1 file (checked 1 source file) ``` and: ``` axelrod/strategies/ann.py:125: error: Incompatible types in assignment (expression has type "ndarray", variable has type "List[int]") 34 Found 1 error in 1 file (checked 1 source file) ``` note that in `fingerprint.py` I'm using the `ArrayType` type but that when I tried to use that in `ann.py` a lot of other things broke.
1 parent 6ac633a commit f1f838d

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

axelrod/fingerprint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from axelrod.strategy_transformers import DualTransformer, JossAnnTransformer
1717
from mpl_toolkits.axes_grid1 import make_axes_locatable
1818

19+
from numpy.typing import ArrayLike
20+
1921
Point = namedtuple("Point", "x y")
2022

2123

@@ -188,7 +190,7 @@ def _generate_data(interactions: dict, points: list, edges: list) -> dict:
188190
return point_scores
189191

190192

191-
def _reshape_data(data: dict, points: list, size: int) -> np.ndarray:
193+
def _reshape_data(data: dict, points: list, size: int) -> ArrayLike:
192194
"""Shape the data so that it can be plotted easily.
193195
194196
Parameters
@@ -441,7 +443,7 @@ def fingerprint(
441443
filename: str = None,
442444
progress_bar: bool = True,
443445
seed: int = None,
444-
) -> np.array:
446+
) -> ArrayLike:
445447
"""Creates a spatial tournament to run the necessary matches to obtain
446448
fingerprint data.
447449

axelrod/strategies/ann.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def num_weights(num_features, num_hidden):
2222
return size
2323

2424

25-
def compute_features(player: Player, opponent: Player) -> List[int]:
25+
def compute_features(player: Player, opponent: Player) -> np.ndarray:
2626
"""
2727
Compute history features for Neural Network:
2828
* Opponent's first move is C
@@ -91,7 +91,7 @@ def compute_features(player: Player, opponent: Player) -> List[int]:
9191
total_player_c = player.cooperations
9292
total_player_d = player.defections
9393

94-
return [
94+
return np.array((
9595
opponent_first_c,
9696
opponent_first_d,
9797
opponent_second_c,
@@ -109,20 +109,19 @@ def compute_features(player: Player, opponent: Player) -> List[int]:
109109
total_player_c,
110110
total_player_d,
111111
len(player.history),
112-
]
112+
))
113113

114114

115115
def activate(
116116
bias: List[float],
117117
hidden: List[float],
118118
output: List[float],
119-
inputs: List[int],
119+
inputs: np.ndarray,
120120
) -> float:
121121
"""
122122
Compute the output of the neural network:
123123
output = relu(inputs * hidden_weights + bias) * output_weights
124124
"""
125-
inputs = np.array(inputs)
126125
hidden_values = bias + np.dot(hidden, inputs)
127126
hidden_values = relu(hidden_values)
128127
output_value = np.dot(hidden_values, output)

0 commit comments

Comments
 (0)