Commit ea16830

Merge remote-tracking branch 'origin/master'
2 parents: 177cc3d + 269cc9c

File tree

4 files changed: +19 additions, -14 deletions

aalpy/learning_algs/general_passive/GeneralizedStateMerging.py

Lines changed: 2 additions & 2 deletions

@@ -35,7 +35,7 @@ def log_promote(self, node: GsmNode):
     def log_merge(self, part: Partitioning):
         pass
 
-    def learning_done(self, root: GsmNode, red_states: List[GsmNode]):
+    def learning_done(self, root: GsmNode):
         pass
 
 
@@ -192,7 +192,7 @@ def run(self, data, convert=True, instrumentation: Instrumentation=None, data_fo
             # FUTURE: caching for aggregating compatibility tests
             partition_candidates.clear()
 
-        instrumentation.learning_done(root, red_states)
+        instrumentation.learning_done(root)
 
         root = self.postprocessing(root)
         if convert:
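
Note on the interface change: learning_done now receives only the root of the learned model, so custom instrumentation hooks can no longer rely on the red-state list being passed in. A minimal sketch of a hook written against the new signature (the MergeCounter class and its counting logic are illustrative, not part of this commit; get_all_nodes is the GsmNode method already used elsewhere in this diff):

    from aalpy.learning_algs.general_passive.GeneralizedStateMerging import Instrumentation, Partitioning
    from aalpy.learning_algs.general_passive.GsmNode import GsmNode

    class MergeCounter(Instrumentation):
        # Illustrative hook: counts merges and reports the final model size.
        merges = 0

        def log_merge(self, part: Partitioning):
            self.merges += 1

        def learning_done(self, root: GsmNode):
            # red_states is no longer passed in; derive the model size from root instead.
            print(f"{len(root.get_all_nodes())} states after {self.merges} merges")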

aalpy/learning_algs/general_passive/GsmNode.py

Lines changed: 2 additions & 2 deletions

@@ -64,8 +64,6 @@ def detect_data_format(data, check_consistency=False, guess=False):
     # - "traces": [[o, o, ...], ...]
 
     if isinstance(data, GsmNode):
-        if not data.is_tree():
-            raise ValueError("provided automaton is not a tree")
         return "tree"
 
     accepted_types = (Tuple, List)

@@ -432,6 +430,8 @@ def createPTA(data, output_behavior, data_format=None) -> 'GsmNode':
         raise ValueError(f"invalid data format {data_format}. should be in {DataFormatRange}")
 
     if data_format == "tree":
+        if not data.is_tree():
+            raise ValueError("provided automaton is not a tree")
         return data
     root_node = GsmNode((None, unknown_output), None)
     if data_format == "labeled_sequences":
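
With this move, detect_data_format classifies any GsmNode as "tree" without validating it; the is_tree check only fires once the PTA is actually built. A hedged sketch of where the error now surfaces (existing_root is a placeholder for a GsmNode built elsewhere, and the "moore" value for output_behavior is an assumption, not taken from this diff):

    from aalpy.learning_algs.general_passive.GsmNode import createPTA

    # existing_root is a placeholder for a GsmNode obtained elsewhere.
    try:
        pta = createPTA(existing_root, output_behavior="moore", data_format="tree")
    except ValueError:
        # Raised here (not during format detection) if the node graph is not a tree.
        print("provided automaton is not a tree")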

aalpy/learning_algs/general_passive/Instrumentation.py

Lines changed: 7 additions & 7 deletions

@@ -1,4 +1,4 @@
-import time
+from time import perf_counter
 from typing import Dict, Optional
 
 from aalpy.learning_algs.general_passive.GeneralizedStateMerging import Instrumentation, Partitioning, \

@@ -29,18 +29,18 @@ def reset(self, gsm: GeneralizedStateMerging):
         self.nr_merged_states = 0
         self.nr_red_states = 0
 
-        self.previous_time = time.time()
+        self.previous_time = perf_counter()
 
     def pta_construction_done(self, root):
-        print(f'PTA Construction Time: {round(time.time() - self.previous_time, 2)}')
+        print(f'PTA Construction Time: {round(perf_counter() - self.previous_time, 2)} s')
         if 1 < self.lvl:
             states = root.get_all_nodes()
             leafs = [state for state in states if len(state.transitions.keys()) == 0]
             depth = [state.get_prefix_length() for state in leafs]
             self.pta_size = len(states)
             print(f'PTA has {len(states)} states leading to {len(leafs)} leafs')
             print(f'min / avg / max depth : {min(depth)} / {sum(depth) / len(depth)} / {max(depth)}')
-        self.previous_time = time.time()
+        self.previous_time = perf_counter()
 
     def print_status(self):
         reset_char = "\33[2K\r"

@@ -60,9 +60,9 @@ def log_merge(self, part: Partitioning):
         self.nr_merged_states += 1
         self.print_status()
 
-    def learning_done(self, root, red_states):
-        print(f'\nLearning Time: {round(time.time() - self.previous_time, 2)}')
-        print(f'Learned {len(red_states)} state automaton via {self.nr_merged_states} merges.')
+    def learning_done(self, root: GsmNode):
+        print(f'\nLearning Time: {round(perf_counter() - self.previous_time, 2)} s')
+        print(f'Learned {self.nr_red_states} state automaton via {self.nr_merged_states} merges.')
         if 2 < self.lvl:
             root.visualize("model", self.gsm.output_behavior)
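
The switch from time.time to time.perf_counter matters because perf_counter is a monotonic, high-resolution clock: measured intervals cannot be distorted by system clock adjustments, which can happen with time.time. A minimal standalone sketch of the timing pattern the instrumentation now uses (the workload line is a placeholder):

    from time import perf_counter

    start = perf_counter()
    _ = sum(range(10_000_000))          # placeholder workload
    elapsed = perf_counter() - start    # seconds, monotonic
    print(f'Elapsed: {round(elapsed, 2)} s')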

aalpy/oracles/StatePrefixEqOracle.py

Lines changed: 8 additions & 3 deletions

@@ -12,7 +12,7 @@ class StatePrefixEqOracle(Oracle):
     rand_walk_len exactly walk_per_state times during learning. Therefore excessive testing of initial states is
     avoided.
     """
-    def __init__(self, alphabet: list, sul: SUL, walks_per_state=10, walk_len=12, depth_first=False):
+    def __init__(self, alphabet: list, sul: SUL, walks_per_state=10, walk_len=12, max_tests=None, depth_first=True):
         """
         Args:
 
@@ -24,18 +24,20 @@ def __init__(self, alphabet: list, sul: SUL, walks_per_state=10, walk_len=12, de
 
             walk_len:length of random walk
 
-            depth_first:first explore newest states
+            max_tests:number of maximum tests. If set to None, this parameter will be ignored.
+
+            depth_first:first explore the newest states
         """
 
         super().__init__(alphabet, sul)
         self.walks_per_state = walks_per_state
         self.steps_per_walk = walk_len
         self.depth_first = depth_first
+        self.max_tests = max_tests
 
         self.freq_dict = dict()
 
     def find_cex(self, hypothesis):
-
         states_to_cover = []
         for state in hypothesis.states:
             if state.prefix is None:

@@ -57,6 +59,9 @@ def find_cex(self, hypothesis):
 
             self.reset_hyp_and_sul(hypothesis)
 
+            if self.max_tests and self.num_queries == self.max_tests:
+                return None
+
             prefix = state.prefix
             for p in prefix:
                 hypothesis.step(p)
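
With max_tests set, find_cex returns None once the oracle's cumulative query count reaches the limit, which bounds the total testing effort across equivalence queries. A usage sketch (input_alphabet and my_sul are placeholders for an existing input alphabet and SUL instance; the keyword values are illustrative):

    from aalpy.oracles import StatePrefixEqOracle

    # input_alphabet and my_sul stand in for an existing input alphabet and SUL.
    eq_oracle = StatePrefixEqOracle(input_alphabet, my_sul,
                                    walks_per_state=10, walk_len=12,
                                    max_tests=500)  # stop once 500 queries have been issued

    # The oracle is then passed to a learning algorithm (e.g. run_Lstar) as usual.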
