11import functools
2- from bisect import insort
3- from typing import Dict , Tuple , Callable , List , Optional
42from collections import deque
3+ from typing import Dict , Tuple , Callable , List , Optional
54
65from aalpy .learning_algs .general_passive .Node import Node , OutputBehavior , TransitionBehavior , TransitionInfo , \
7- OutputBehaviorRange , TransitionBehaviorRange , intersection_iterator
6+ OutputBehaviorRange , TransitionBehaviorRange , intersection_iterator , NodeOrders , unknown_output , detect_data_format
87from aalpy .learning_algs .general_passive .ScoreFunctionsGSM import ScoreCalculation , hoeffding_compatibility
98
109
@@ -54,10 +53,10 @@ def __init__(self, *,
5453 depth_first = False ):
5554
5655 if output_behavior not in OutputBehaviorRange :
57- raise ValueError (f"invalid output behavior { output_behavior } " )
56+ raise ValueError (f"invalid output behavior { output_behavior } . should be in { OutputBehaviorRange } " )
5857 self .output_behavior : OutputBehavior = output_behavior
5958 if transition_behavior not in TransitionBehaviorRange :
60- raise ValueError (f"invalid transition behavior { transition_behavior } " )
59+ raise ValueError (f"invalid transition behavior { transition_behavior } . should be in { TransitionBehaviorRange } " )
6160 self .transition_behavior : TransitionBehavior = transition_behavior
6261
6362 if score_calc is None :
@@ -70,8 +69,11 @@ def __init__(self, *,
7069 self .score_calc : ScoreCalculation = score_calc
7170
7271 if node_order is None :
73- node_order = Node .__lt__
74- self .node_order = functools .cmp_to_key (lambda a , b : - 1 if node_order (a , b ) else 1 )
72+ node_order = NodeOrders .Default
73+ if node_order is NodeOrders .NoCompare or node_order is NodeOrders .Default :
74+ self .node_order = node_order
75+ else :
76+ self .node_order = functools .cmp_to_key (lambda a , b : - 1 if node_order (a , b ) else 1 )
7577
7678 self .pta_preprocessing = pta_preprocessing or (lambda x : x )
7779 self .postprocessing = postprocessing or (lambda x : x )
@@ -91,15 +93,16 @@ def compute_local_compatibility(self, a: Node, b: Node):
9193
9294 # TODO: make more generic by adding the option to use a different algorithm than red blue
9395 # for selecting potential merge candidates. Maybe using inheritance with abstract `run`.
94- def run (self , data , convert = True , instrumentation : Instrumentation = None ):
96+ def run (self , data , convert = True , instrumentation : Instrumentation = None , data_format = None ):
9597 if instrumentation is None :
9698 instrumentation = Instrumentation ()
9799 instrumentation .reset (self )
98100
99- if isinstance (data , Node ):
100- root = data
101- else :
102- root = Node .createPTA (data , self .output_behavior )
101+ if data_format is None :
102+ data_format = detect_data_format (data )
103+ if data_format == "examples" and self .transition_behavior != "deterministic" :
104+ raise ValueError ("learning from examples is not possible for nondeterministic systems" )
105+ root = Node .createPTA (data , self .output_behavior , data_format )
103106
104107 root = self .pta_preprocessing (root )
105108 instrumentation .pta_construction_done (root )
@@ -128,7 +131,11 @@ def run(self, data, convert=True, instrumentation: Instrumentation = None):
128131 # no blue states left -> done
129132 if len (blue_states ) == 0 :
130133 break
131- blue_states .sort (key = self .node_order )
134+ if self .node_order is not NodeOrders .NoCompare :
135+ blue_states .sort (key = self .node_order )
136+ # red states are always sorted using default order on original prefix
137+ if self .node_order is not NodeOrders .Default :
138+ red_states .sort (key = self .node_order )
132139
133140 # loop over blue states
134141 promotion = False
@@ -139,7 +146,6 @@ def run(self, data, convert=True, instrumentation: Instrumentation = None):
139146 # calculate partitions resulting from merges with red states if necessary
140147 current_candidates : Dict [Node , Partitioning ] = dict ()
141148 perfect_partitioning = None
142-
143149 red_state = None
144150 for red_state in red_states :
145151 partition = partition_candidates .get ((red_state , blue_state ))
@@ -149,16 +155,16 @@ def run(self, data, convert=True, instrumentation: Instrumentation = None):
149155 perfect_partitioning = partition
150156 break
151157 current_candidates [red_state ] = partition
152-
153158 assert red_state is not None
159+
154160 # partition with perfect score found: don't consider anything else
155161 if perfect_partitioning :
156162 partition_candidates = {(red_state , blue_state ): perfect_partitioning }
157163 break
158164
159165 # no merge candidates for this blue state -> promote
160166 if all (part .score is False for part in current_candidates .values ()):
161- insort ( red_states , blue_state , key = self . node_order )
167+ red_states . append ( blue_state )
162168 instrumentation .log_promote (blue_state )
163169 promotion = True
164170 break
@@ -176,10 +182,11 @@ def run(self, data, convert=True, instrumentation: Instrumentation = None):
176182 best_candidate = max (partition_candidates .values (), key = lambda part : part .score )
177183 for real_node , partition_node in best_candidate .red_mapping .items ():
178184 real_node .transitions = partition_node .transitions
185+ real_node .prefix_access_pair = partition_node .prefix_access_pair
179186 for access_pair , t_info in real_node .transition_iterator ():
180187 if t_info .target not in red_states :
181188 t_info .target .predecessor = real_node
182- t_info .target .prefix_access_pair = access_pair # not sure whether this is actually required
189+ # t_info.target.prefix_access_pair = access_pair # not sure whether this is actually required
183190 instrumentation .log_merge (best_candidate )
184191 # FUTURE: optimizations for compatibility tests where merges can be orthogonal
185192 # FUTURE: caching for aggregating compatibility tests
@@ -247,9 +254,13 @@ def update_partition(red_node: Node, blue_node: Optional[Node]) -> Node:
247254 blue_in_sym , blue_out_sym = blue .prefix_access_pair
248255 blue_parent .transitions [blue_in_sym ][blue_out_sym ].target = red
249256
257+ partition = update_partition (red , None )
258+ if self .output_behavior == "moore" :
259+ partition .resolve_unknown_prefix_output (blue_out_sym )
260+
261+ # loop over implied merges
250262 q : deque [Tuple [Node , Node ]] = deque ([(red , blue )])
251263 pop = q .pop if self .depth_first else q .popleft
252-
253264 while len (q ) != 0 :
254265 red , blue = pop ()
255266 partition = update_partition (red , blue )
@@ -258,10 +269,25 @@ def update_partition(red_node: Node, blue_node: Optional[Node]) -> Node:
258269 if self .compute_local_compatibility (partition , blue ) is False :
259270 return partitioning
260271
272+ # create implied merges for all common successors
261273 for in_sym , blue_transitions in blue .transitions .items ():
262274 partition_transitions = partition .get_or_create_transitions (in_sym )
263275 for out_sym , blue_transition in blue_transitions .items ():
264276 partition_transition = partition_transitions .get (out_sym )
277+ # handle unknown output
278+ if partition_transition is None :
279+ if out_sym is unknown_output and len (partition_transitions ) != 0 :
280+ assert len (partition_transitions ) == 1
281+ partition_transition = list (partition_transitions .values ())[0 ]
282+ if unknown_output in partition_transitions :
283+ assert len (partition_transitions ) == 1
284+ partition_transition = partition_transitions .pop (unknown_output )
285+ partition_transitions [out_sym ] = partition_transition
286+ # re-hook access pair
287+ succ_part = update_partition (partition_transition .target , None )
288+ if self .output_behavior == "moore" or succ_part .predecessor is red :
289+ succ_part .resolve_unknown_prefix_output (out_sym )
290+ # add pairs
265291 if partition_transition is not None :
266292 q .append ((partition_transition .target , blue_transition .target ))
267293 partition_transition .count += blue_transition .count
@@ -287,6 +313,7 @@ def run_GSM(data, *,
287313 depth_first = False ,
288314 instrumentation = None ,
289315 convert = True ,
316+ data_format = None ,
290317 ):
291318 """
292319 TODO
@@ -318,12 +345,14 @@ def run_GSM(data, *,
318345
319346 convert:
320347
348+ data_format:
349+
321350
322351 Returns:
323352
324353
325354 """
326- # instantiate the gsm
355+ # instantiate gsm
327356 gsm = GeneralizedStateMerging (
328357 output_behavior = output_behavior ,
329358 transition_behavior = transition_behavior ,
@@ -338,4 +367,4 @@ def run_GSM(data, *,
338367 )
339368
340369 # run the algorithm
341- return gsm .run (data = data , instrumentation = instrumentation , convert = convert )
370+ return gsm .run (data = data , instrumentation = instrumentation , convert = convert , data_format = data_format )
0 commit comments