Skip to content

Commit ee13777

Browse files
gaffney2010marcharper
authored andcommitted
Algorithm for Memory Depth in FSM (#1233)
* Clean up docstrings, to match parameter-listing to actual parameters. * Response to meatballs' feedback. * Added a get_memory_from_transitions function for FSM. * Change default verbosity of get_memory_from_transitions. * Fixed build error, and made less verbose. * Change the format of transitions in get_memory_from_transitions to dict. * Added types to get_memory_from_transitions. * Make ActionChains faster. * Replaced get_accessible_transitions with a much faster version. * Updated tests for new memory classifiers. * Fixed mypy and doctest errors. * Updated metastrategies for new finite set. * Moved blocks/added comments for readability. * More specific typing. * Responded to Marc's comments. * Import List from typing. * Change DefaultDict type and added tit_for_five_tat test. * Remove type on all_memits * Remove more typing. * Fixed type on longest_path argument. * Responding to Marc's comments. * Fixed Tuple annotation. * Add a memory test to default FSM test. * Responded to some of drvinceknight comments on memory depth. * Fixed some errors. * Moved FSM memory functions to separate top-level file compute_finite_state_machine_memory * Added additional topic documentation for FSM/memory. * Fix code in new documentation. * Move unit tests for compute FSM memory and add order_memit_tuple. * Minor changes to meta_strategies doc. * Update bibliography.rst * Undid changes to usually coop/def. Move to different commit. * Delete old comment. * Remove memory tests from FSM test file; these have already been copied. * Add memory library back into FSM test. * Minor fixes to compute memory tests.
1 parent 2f24224 commit ee13777

File tree

9 files changed

+742
-62
lines changed

9 files changed

+742
-62
lines changed
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
from axelrod.action import Action
2+
from collections import defaultdict, namedtuple
3+
from typing import DefaultDict, Iterator, Dict, Tuple, Set, List
4+
5+
C, D = Action.C, Action.D
6+
7+
Transition = namedtuple(
8+
"Transition", ["state", "last_opponent_action", "next_state", "next_action"]
9+
)
10+
TransitionDict = Dict[Tuple[int, Action], Tuple[int, Action]]
11+
12+
13+
class Memit(object):
14+
"""
15+
Memit = unit of memory.
16+
17+
This represents the amount of memory that we gain with each new piece of
18+
history. It includes a state, our_response that we make on our way into that
19+
state (in_act), and the opponent's action that makes us move out of that state
20+
(out_act).
21+
22+
For example, for this finite state machine:
23+
(0, C, 0, C),
24+
(0, D, 1, C),
25+
(1, C, 0, D),
26+
(1, D, 0, D)
27+
28+
Has the memits:
29+
(C, 0, C),
30+
(C, 0, D),
31+
(D, 0, C),
32+
(D, 0, D),
33+
(C, 1, C),
34+
(C, 1, D)
35+
"""
36+
37+
def __init__(self, in_act: Action, state: int, out_act: Action):
38+
self.in_act = in_act
39+
self.state = state
40+
self.out_act = out_act
41+
42+
def __repr__(self) -> str:
43+
return "{}, {}, {}".format(self.in_act, self.state, self.out_act)
44+
45+
def __hash__(self):
46+
return hash(repr(self))
47+
48+
def __eq__(self, other_memit) -> bool:
49+
"""In action and out actions are the same."""
50+
return (
51+
self.in_act == other_memit.in_act
52+
and self.out_act == other_memit.out_act
53+
)
54+
55+
def __lt__(self, other_memit) -> bool:
56+
return repr(self) < repr(other_memit)
57+
58+
59+
MemitPair = Tuple[Memit, Memit]
60+
61+
62+
def ordered_memit_tuple(x: Memit, y: Memit) -> tuple:
63+
"""Returns a tuple of x in y, sorted so that (x, y) are viewed as the
64+
same as (y, x).
65+
"""
66+
if x < y:
67+
return (x, y)
68+
else:
69+
return (y, x)
70+
71+
72+
def transition_iterator(transitions: TransitionDict) -> Iterator[Transition]:
73+
"""Changes the transition dictionary into a iterator on namedtuples."""
74+
for k, v in transitions.items():
75+
yield Transition(k[0], k[1], v[0], v[1])
76+
77+
78+
def get_accessible_transitions(
79+
transitions: TransitionDict, initial_state: int
80+
) -> TransitionDict:
81+
"""Gets all transitions from the list that can be reached from the
82+
initial_state.
83+
"""
84+
# Initial dict of edges between states and a dict of visited status for each
85+
# of the states.
86+
edge_dict = defaultdict(list) # type: DefaultDict[int, List[int]]
87+
visited = dict()
88+
for trans in transition_iterator(transitions):
89+
visited[trans.state] = False
90+
edge_dict[trans.state].append(trans.next_state)
91+
# Keep track of states that can be accessed.
92+
accessible_states = [initial_state]
93+
94+
state_queue = [initial_state]
95+
visited[initial_state] = True
96+
# While there are states in the queue, visit all its children, adding each
97+
# to the accesible_states. [A basic breadth-first search.]
98+
while len(state_queue) > 0:
99+
state = state_queue.pop()
100+
for successor in edge_dict[state]:
101+
# Don't process the same state twice.
102+
if not visited[successor]:
103+
visited[successor] = True
104+
state_queue.append(successor)
105+
accessible_states.append(successor)
106+
107+
# Now for each transition in the passed TransitionDict, copy the transition
108+
# to accessible_transitions if and only if the starting state is accessible,
109+
# as determined above.
110+
accessible_transitions = dict()
111+
for trans in transition_iterator(transitions):
112+
if trans.state in accessible_states:
113+
accessible_transitions[
114+
(trans.state, trans.last_opponent_action)
115+
] = (trans.next_state, trans.next_action)
116+
117+
return accessible_transitions
118+
119+
120+
def longest_path(
121+
edges: DefaultDict[MemitPair, Set[MemitPair]], starting_at: MemitPair
122+
) -> int:
123+
"""Returns the number of nodes in the longest path that starts at the given
124+
node. Returns infinity if a loop is encountered.
125+
"""
126+
visited = dict()
127+
for source, destinations in edges.items():
128+
visited[source] = False
129+
for destination in destinations:
130+
visited[destination] = False
131+
132+
# This is what we'll recurse on. visited dict is shared between calls.
133+
def recurse(at_node):
134+
visited[at_node] = True
135+
record = 1 # Count the nodes, not the edges.
136+
for successor in edges[at_node]:
137+
if visited[successor]:
138+
return float("inf")
139+
successor_length = recurse(successor)
140+
if successor_length == float("inf"):
141+
return float("inf")
142+
if record < successor_length + 1:
143+
record = successor_length + 1
144+
return record
145+
146+
return recurse(starting_at)
147+
148+
149+
def get_memory_from_transitions(
150+
transitions: TransitionDict,
151+
initial_state: int = None,
152+
all_actions: Tuple[Action, Action] = (C, D),
153+
) -> int:
154+
"""This function calculates the memory of an FSM from the transitions.
155+
156+
Assume that transitions are a dict with entries like
157+
(state, last_opponent_action): (next_state, next_action)
158+
159+
We first break down the transitions into memits (see above). We also create
160+
a graph of memits, where the successor to a given memit are all possible
161+
memits that could occur in the memory immediately before the given memit.
162+
163+
Then we pair up memits with different states, but same in and out actions.
164+
These represent points in time that we can't determine which state we're in.
165+
We also create a graph of memit-pairs, where memit-pair, Y, succeeds a
166+
memit-pair, X, if the two memits in X are succeeded by the two memits in Y.
167+
These edges reperesent consecutive points in time that we can't determine
168+
which state we're in.
169+
170+
Then for all memit-pairs that disagree, in the sense that they imply
171+
different next_action, we find the longest chain starting at that
172+
memit-pair. [If a loop is encountered then this will be infinite.] We take
173+
the maximum over all such memit-pairs. This represents the longest possible
174+
chain of memory for which we wouldn't know what to do next. We return this.
175+
"""
176+
# If initial_state is set, use this to determine which transitions are
177+
# reachable from the initial_state and restrict to those.
178+
if initial_state is not None:
179+
transitions = get_accessible_transitions(transitions, initial_state)
180+
181+
# Get the incoming actions for each state.
182+
incoming_action_by_state = defaultdict(
183+
set
184+
) # type: DefaultDict[int, Set[Action]]
185+
for trans in transition_iterator(transitions):
186+
incoming_action_by_state[trans.next_state].add(trans.next_action)
187+
188+
# Keys are starting memit, and values are all possible terminal memit.
189+
# Will walk backwards through the graph.
190+
memit_edges = defaultdict(set) # type: DefaultDict[Memit, Set[Memit]]
191+
for trans in transition_iterator(transitions):
192+
# Since all actions are out-paths for each state, add all of these.
193+
# That is to say that the opponent could do anything
194+
for out_action in all_actions:
195+
# More recent in action history
196+
starting_node = Memit(
197+
trans.next_action, trans.next_state, out_action
198+
)
199+
# All incoming paths to current state
200+
for in_action in incoming_action_by_state[trans.state]:
201+
# Less recent in action history
202+
ending_node = Memit(
203+
in_action, trans.state, trans.last_opponent_action
204+
)
205+
memit_edges[starting_node].add(ending_node)
206+
207+
all_memits = list(memit_edges.keys())
208+
209+
pair_nodes = set()
210+
pair_edges = defaultdict(
211+
set
212+
) # type: DefaultDict[MemitPair, Set[MemitPair]]
213+
# Loop through all pairs of memits.
214+
for x, y in [(x, y) for x in all_memits for y in all_memits]:
215+
if x == y and x.state == y.state:
216+
continue
217+
if x != y:
218+
continue
219+
220+
# If the memits match, then the strategy can't tell the difference
221+
# between the states. We call this a pair of matched memits (or just a
222+
# pair).
223+
pair_nodes.add(ordered_memit_tuple(x, y))
224+
# When two memits in matched pair have successors that are also matched,
225+
# then we draw an edge. This represents consecutive historical times
226+
# that we can't tell which state we're in.
227+
for x_successor in memit_edges[x]:
228+
for y_successor in memit_edges[y]:
229+
if x_successor == y_successor:
230+
pair_edges[ordered_memit_tuple(x, y)].add(
231+
ordered_memit_tuple(x_successor, y_successor)
232+
)
233+
234+
if len(pair_nodes) == 0:
235+
# If there are no pair of tied memits, then either no memits are needed
236+
# to break a tie (i.e. all next_actions are the same) or the first memit
237+
# breaks a tie (i.e. memory 1)
238+
next_action_set = set()
239+
for trans in transition_iterator(transitions):
240+
next_action_set.add(trans.next_action)
241+
if len(next_action_set) == 1:
242+
return 0
243+
return 1
244+
245+
# Get next_action for each memit. Used to decide if they are in conflict,
246+
# because we only have undecidability if next_action doesn't match.
247+
next_action_by_memit = dict()
248+
for trans in transition_iterator(transitions):
249+
for in_action in incoming_action_by_state[trans.state]:
250+
memit_key = Memit(
251+
in_action, trans.state, trans.last_opponent_action
252+
)
253+
next_action_by_memit[memit_key] = trans.next_action
254+
255+
# Calculate the longest path.
256+
record = 0
257+
for pair in pair_nodes:
258+
if next_action_by_memit[pair[0]] != next_action_by_memit[pair[1]]:
259+
# longest_path is the longest chain of tied states. We add one to
260+
# get the memory length needed to break all ties.
261+
path_length = longest_path(pair_edges, pair) + 1
262+
if record < path_length:
263+
record = path_length
264+
return record
265+

0 commit comments

Comments
 (0)