11import functools
22import math
33import pathlib
4+ from collections import defaultdict
45from functools import total_ordering
56from typing import Dict , Any , List , Tuple , Iterable , Callable , Union , TypeVar , Iterator , Optional , Sequence
67import pydot
@@ -125,7 +126,7 @@ class GsmNode:
125126
126127 def __init__ (self , prefix_access_pair , predecessor : 'GsmNode' = None ):
127128 # TODO try single dict
128- self .transitions : Dict [Any , Dict [Any , TransitionInfo ]] = dict ( )
129+ self .transitions : defaultdict [Any , Dict [Any , TransitionInfo ]] = defaultdict ( dict )
129130 self .predecessor : GsmNode = predecessor
130131 self .prefix_access_pair = prefix_access_pair
131132
@@ -187,17 +188,17 @@ def get_or_create_transitions(self, in_sym) -> Dict[Any, TransitionInfo]:
187188 self .transitions [in_sym ] = t
188189 return t
189190
190- def transition_iterator (self ) -> Iterable [Tuple [Tuple [ Any , Any ] , TransitionInfo ]]:
191+ def transition_iterator (self ) -> Iterable [Tuple [Any , Any , TransitionInfo ]]:
191192 for in_sym , transitions in self .transitions .items ():
192193 for out_sym , node in transitions .items ():
193- yield ( in_sym , out_sym ) , node
194+ yield in_sym , out_sym , node
194195
195196 def shallow_copy (self ) -> 'GsmNode' :
196197 node = GsmNode (self .prefix_access_pair , self .predecessor )
197198 for in_sym , t in self .transitions .items ():
198- d = dict ()
199- for out_sym , v in t .items ():
200- d [out_sym ] = copy ( v )
199+ d = dict () # appears to be faster than dict comprehension
200+ for out_sym , ti in t .items ():
201+ d [out_sym ] = TransitionInfo ( ti . target , ti . count , ti . original_target , ti . original_count )
201202 node .transitions [in_sym ] = d
202203 return node
203204
@@ -219,7 +220,7 @@ def get_all_nodes(self) -> List['GsmNode']:
219220 result = [self ]
220221 backing_set = {self }
221222 for state in result :
222- for _ , transition in state .transition_iterator ():
223+ for _ , _ , transition in state .transition_iterator ():
223224 child = transition .target
224225 if child not in backing_set :
225226 backing_set .add (child )
@@ -231,7 +232,7 @@ def is_tree(self):
231232 backing_set = {self }
232233 while len (q ) != 0 :
233234 current = q .pop (0 )
234- for _ , transition in current .transition_iterator ():
235+ for _ , _ , transition in current .transition_iterator ():
235236 child = transition .target
236237 if child in backing_set :
237238 return False
@@ -324,7 +325,7 @@ def state_label(node: GsmNode):
324325 return f'{ node .get_prefix_output ()} { node .count ()} '
325326 else :
326327 def state_label (node : GsmNode ):
327- return f'{ sum (t .count for _ , t in node .transition_iterator ())} '
328+ return f'{ sum (t .count for _ , _ , t in node .transition_iterator ())} '
328329 if trans_label is None and "label" not in trans_props :
329330 if output_behavior == "moore" :
330331 def trans_label (node : GsmNode , in_sym , out_sym ):
@@ -379,7 +380,7 @@ def node_naming(node: GsmNode):
379380 def add_trace (self , trace : IOTrace ):
380381 curr_node : GsmNode = self
381382 for in_sym , out_sym in trace :
382- transitions = curr_node .get_or_create_transitions ( in_sym )
383+ transitions = curr_node .transitions [ in_sym ]
383384 info = transitions .get (out_sym )
384385 if info is None :
385386 node = GsmNode ((in_sym , out_sym ), curr_node )
@@ -397,7 +398,7 @@ def add_labeled_sequence(self, example: IOExample):
397398
398399 # step through inputs and add transitions
399400 for in_sym in inputs :
400- transitions = curr_node .get_or_create_transitions ( in_sym )
401+ transitions = curr_node .transitions [ in_sym ]
401402 t_infos = list (transitions .values ())
402403 if len (t_infos ) == 0 :
403404 node = GsmNode ((in_sym , unknown_output ), curr_node )
@@ -457,13 +458,13 @@ def deterministic_compatible(self, other: 'GsmNode'):
457458 for _ , trans_self , trans_other in intersection_iterator (self .transitions , other .transitions ):
458459 if unknown_output in trans_self or unknown_output in trans_other :
459460 continue
460- if list ( trans_self .keys ()) != list ( trans_other .keys () ):
461+ if trans_self .keys () != trans_other .keys ():
461462 return False
462463 return True
463464
464465 def is_moore (self ):
465466 for node in self .get_all_nodes ():
466- for ( in_sym , out_sym ) , transition in node .transition_iterator ():
467+ for in_sym , out_sym , transition in node .transition_iterator ():
467468 child_output = transition .target .get_prefix_output ()
468469 if out_sym is not unknown_output and child_output != out_sym :
469470 return False
@@ -486,9 +487,6 @@ def local_log_likelihood_contribution(self):
486487 return llc
487488
488489 def count (self ):
489- return sum (trans .count for _ , trans in self .transition_iterator ())
490+ return sum (trans .count for _ , _ , trans in self .transition_iterator ())
490491
491-
492- class NodeOrders :
493- NoCompare = lambda n : 0
494- Default = functools .cmp_to_key (lambda a , b : - 1 if a < b else 1 )
492+ default_order = functools .cmp_to_key (lambda a , b : - 1 if a < b else 1 )
0 commit comments