Skip to content

Commit 2877d2f

Browse files
authored
Merge pull request #86 from DES-Lab/gsm-dev
Improved GSM performance
2 parents 7dd63d3 + 05a23a8 commit 2877d2f

File tree

2 files changed

+33
-35
lines changed

2 files changed

+33
-35
lines changed

aalpy/learning_algs/general_passive/GeneralizedStateMerging.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Dict, Tuple, Callable, List, Optional
44

55
from aalpy.learning_algs.general_passive.GsmNode import GsmNode, OutputBehavior, TransitionBehavior, TransitionInfo, \
6-
OutputBehaviorRange, TransitionBehaviorRange, intersection_iterator, NodeOrders, unknown_output, detect_data_format
6+
OutputBehaviorRange, TransitionBehaviorRange, intersection_iterator, unknown_output, detect_data_format
77
from aalpy.learning_algs.general_passive.ScoreFunctionsGSM import ScoreCalculation, hoeffding_compatibility
88

99

@@ -69,9 +69,7 @@ def __init__(self, *,
6969
self.score_calc: ScoreCalculation = score_calc
7070

7171
if node_order is None:
72-
node_order = NodeOrders.Default
73-
if node_order is NodeOrders.NoCompare or node_order is NodeOrders.Default:
74-
self.node_order = node_order
72+
self.node_order = GsmNode.default_order
7573
else:
7674
self.node_order = functools.cmp_to_key(lambda a, b: -1 if node_order(a, b) else 1)
7775

@@ -119,25 +117,28 @@ def run(self, data, convert=True, instrumentation: Instrumentation=None, data_fo
119117

120118
partition_candidates: Dict[Tuple[GsmNode, GsmNode], Partitioning] = dict()
121119
while True:
120+
# sort states. states are always sorted using default order on original prefix
121+
if self.node_order is not GsmNode.default_order:
122+
red_states.sort(key=self.node_order)
123+
122124
# get blue states
123125
blue_states = []
124126
for r in red_states:
125-
for _, t in r.transition_iterator():
127+
for _, _, t in r.transition_iterator():
126128
c = t.target
127129
if c in red_states:
128130
continue
129131
blue_states.append(c)
130-
if self.consider_only_min_blue or not self.score_calc.has_score_function():
131-
blue_states = [min(blue_states, key=self.node_order)]
132+
if self.consider_only_min_blue and self.node_order is GsmNode.default_order:
133+
break
132134

133135
# no blue states left -> done
134136
if len(blue_states) == 0:
135137
break
136-
if self.node_order is not NodeOrders.NoCompare:
138+
if self.consider_only_min_blue: # does it make sense to check the score function here?
139+
blue_states = [min(blue_states, key=self.node_order)]
140+
if self.node_order is not GsmNode.default_order:
137141
blue_states.sort(key=self.node_order)
138-
# red states are always sorted using default order on original prefix
139-
if self.node_order is not NodeOrders.Default:
140-
red_states.sort(key=self.node_order)
141142

142143
# loop over blue states
143144
promotion = False
@@ -237,12 +238,11 @@ def update_partition(red_node: GsmNode, blue_node: Optional[GsmNode]) -> GsmNode
237238
return red_node
238239
else:
239240
def update_partition(red_node: GsmNode, blue_node: Optional[GsmNode]) -> GsmNode:
240-
if red_node not in partitioning.full_mapping:
241+
p = partitioning.full_mapping.get(red_node) # could check smaller .red_mapping?
242+
if p is None:
241243
p = red_node.shallow_copy()
242244
partitioning.full_mapping[red_node] = p
243245
partitioning.red_mapping[red_node] = p
244-
else:
245-
p = partitioning.full_mapping[red_node]
246246
if blue_node is not None:
247247
partitioning.full_mapping[blue_node] = p
248248
return p
@@ -269,12 +269,12 @@ def update_partition(red_node: GsmNode, blue_node: Optional[GsmNode]) -> GsmNode
269269

270270
# create implied merges for all common successors
271271
for in_sym, blue_transitions in blue.transitions.items():
272-
partition_transitions = partition.get_or_create_transitions(in_sym)
272+
partition_transitions = partition.transitions[in_sym]
273273
for out_sym, blue_transition in blue_transitions.items():
274274
partition_transition = partition_transitions.get(out_sym)
275275
# handle unknown output
276-
if partition_transition is None:
277-
if out_sym is unknown_output and len(partition_transitions) != 0:
276+
if partition_transition is None and len(partition_transitions) != 0:
277+
if out_sym is unknown_output:
278278
assert len(partition_transitions) == 1
279279
partition_transition = list(partition_transitions.values())[0]
280280
if unknown_output in partition_transitions:

aalpy/learning_algs/general_passive/GsmNode.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22
import math
33
import pathlib
4+
from collections import defaultdict
45
from functools import total_ordering
56
from typing import Dict, Any, List, Tuple, Iterable, Callable, Union, TypeVar, Iterator, Optional, Sequence
67
import pydot
@@ -125,7 +126,7 @@ class GsmNode:
125126

126127
def __init__(self, prefix_access_pair, predecessor: 'GsmNode' = None):
127128
# TODO try single dict
128-
self.transitions: Dict[Any, Dict[Any, TransitionInfo]] = dict()
129+
self.transitions: defaultdict[Any, Dict[Any, TransitionInfo]] = defaultdict(dict)
129130
self.predecessor: GsmNode = predecessor
130131
self.prefix_access_pair = prefix_access_pair
131132

@@ -187,17 +188,17 @@ def get_or_create_transitions(self, in_sym) -> Dict[Any, TransitionInfo]:
187188
self.transitions[in_sym] = t
188189
return t
189190

190-
def transition_iterator(self) -> Iterable[Tuple[Tuple[Any, Any], TransitionInfo]]:
191+
def transition_iterator(self) -> Iterable[Tuple[Any, Any, TransitionInfo]]:
191192
for in_sym, transitions in self.transitions.items():
192193
for out_sym, node in transitions.items():
193-
yield (in_sym, out_sym), node
194+
yield in_sym, out_sym, node
194195

195196
def shallow_copy(self) -> 'GsmNode':
196197
node = GsmNode(self.prefix_access_pair, self.predecessor)
197198
for in_sym, t in self.transitions.items():
198-
d = dict()
199-
for out_sym, v in t.items():
200-
d[out_sym] = copy(v)
199+
d = dict() # appears to be faster than dict comprehension
200+
for out_sym, ti in t.items():
201+
d[out_sym] = TransitionInfo(ti.target, ti.count, ti.original_target, ti.original_count)
201202
node.transitions[in_sym] = d
202203
return node
203204

@@ -219,7 +220,7 @@ def get_all_nodes(self) -> List['GsmNode']:
219220
result = [self]
220221
backing_set = {self}
221222
for state in result:
222-
for _, transition in state.transition_iterator():
223+
for _, _, transition in state.transition_iterator():
223224
child = transition.target
224225
if child not in backing_set:
225226
backing_set.add(child)
@@ -231,7 +232,7 @@ def is_tree(self):
231232
backing_set = {self}
232233
while len(q) != 0:
233234
current = q.pop(0)
234-
for _, transition in current.transition_iterator():
235+
for _, _, transition in current.transition_iterator():
235236
child = transition.target
236237
if child in backing_set:
237238
return False
@@ -324,7 +325,7 @@ def state_label(node: GsmNode):
324325
return f'{node.get_prefix_output()} {node.count()}'
325326
else:
326327
def state_label(node: GsmNode):
327-
return f'{sum(t.count for _, t in node.transition_iterator())}'
328+
return f'{sum(t.count for _, _, t in node.transition_iterator())}'
328329
if trans_label is None and "label" not in trans_props:
329330
if output_behavior == "moore":
330331
def trans_label(node: GsmNode, in_sym, out_sym):
@@ -379,7 +380,7 @@ def node_naming(node: GsmNode):
379380
def add_trace(self, trace: IOTrace):
380381
curr_node: GsmNode = self
381382
for in_sym, out_sym in trace:
382-
transitions = curr_node.get_or_create_transitions(in_sym)
383+
transitions = curr_node.transitions[in_sym]
383384
info = transitions.get(out_sym)
384385
if info is None:
385386
node = GsmNode((in_sym, out_sym), curr_node)
@@ -397,7 +398,7 @@ def add_labeled_sequence(self, example: IOExample):
397398

398399
# step through inputs and add transitions
399400
for in_sym in inputs:
400-
transitions = curr_node.get_or_create_transitions(in_sym)
401+
transitions = curr_node.transitions[in_sym]
401402
t_infos = list(transitions.values())
402403
if len(t_infos) == 0:
403404
node = GsmNode((in_sym, unknown_output), curr_node)
@@ -457,13 +458,13 @@ def deterministic_compatible(self, other: 'GsmNode'):
457458
for _, trans_self, trans_other in intersection_iterator(self.transitions, other.transitions):
458459
if unknown_output in trans_self or unknown_output in trans_other:
459460
continue
460-
if list(trans_self.keys()) != list(trans_other.keys()):
461+
if trans_self.keys() != trans_other.keys():
461462
return False
462463
return True
463464

464465
def is_moore(self):
465466
for node in self.get_all_nodes():
466-
for (in_sym, out_sym), transition in node.transition_iterator():
467+
for in_sym, out_sym, transition in node.transition_iterator():
467468
child_output = transition.target.get_prefix_output()
468469
if out_sym is not unknown_output and child_output != out_sym:
469470
return False
@@ -486,9 +487,6 @@ def local_log_likelihood_contribution(self):
486487
return llc
487488

488489
def count(self):
489-
return sum(trans.count for _, trans in self.transition_iterator())
490+
return sum(trans.count for _, _, trans in self.transition_iterator())
490491

491-
492-
class NodeOrders:
493-
NoCompare = lambda n: 0
494-
Default = functools.cmp_to_key(lambda a, b: -1 if a < b else 1)
492+
default_order = functools.cmp_to_key(lambda a, b: -1 if a < b else 1)

0 commit comments

Comments
 (0)