Skip to content

Commit 3716c29

Browse files
author
Edi Muskardin
committed
test PAPNI correctness on SEVPA characterizing set
1 parent 3458649 commit 3716c29

File tree

1 file changed

+54
-9
lines changed

1 file changed

+54
-9
lines changed

Benchmarking/passive_vpa_vs_rpni.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from random import shuffle
44

55
from aalpy import run_RPNI, run_PAPNI, AutomatonSUL
6-
from aalpy.utils import convert_i_o_traces_for_RPNI, generate_input_output_data_from_vpa
6+
from aalpy.utils import convert_i_o_traces_for_RPNI, generate_input_output_data_from_vpa, is_balanced
77
from aalpy.utils.BenchmarkVpaModels import get_all_VPAs
88
from statistics import mean, stdev
99

@@ -59,7 +59,7 @@ def evaluate_model(learned_model, test_data):
5959
return [rpni_model.size, papni_model.size, rpni_error, papni_error]
6060

6161

62-
def get_sequences_from_active_sevpa(model):
62+
def get_sequences_from_active_sevpa(model, verbose=False):
6363
from aalpy import SUL, run_KV, RandomWordEqOracle, SevpaAlphabet
6464

6565
class CustomSUL(SUL):
@@ -89,9 +89,9 @@ def step(self, letter):
8989
eq_oracle = RandomWordEqOracle(alphabet.get_merged_alphabet(), sul, num_walks=50000, min_walk_len=6,
9090
max_walk_len=30, reset_after_cex=True)
9191
# eq_oracle = BreadthFirstExplorationEqOracle(vpa_alphabet.get_merged_alphabet(), sul, 7)
92-
_ = run_KV(alphabet, sul, eq_oracle, automaton_type='vpa', print_level=3)
92+
lm = run_KV(alphabet, sul, eq_oracle, automaton_type='vpa', print_level=3 if verbose else 0)
9393

94-
return convert_i_o_traces_for_RPNI(sul.sequences)
94+
return convert_i_o_traces_for_RPNI(sul.sequences), lm
9595

9696

9797
def split_data_to_learning_and_testing(data, learning_to_test_ratio=0.5):
@@ -183,12 +183,12 @@ def run_all_experiments_experiments(test_models, learning_to_test_ratio):
183183
print(res_str)
184184

185185

186-
def run_experiments_multiple_times(test_models, num_times):
186+
def run_experiments_multiple_times(test_models, num_times, learning_to_test_ratio=0.5):
187187
all_results = defaultdict(list)
188188
for idx, gt in enumerate(test_models):
189189
for _ in range(num_times):
190190
r = run_experiment(idx, gt, num_of_learning_seq=10000, max_learning_seq_len=50,
191-
random_data_generation=False, learning_to_test_ratio=0.5)
191+
random_data_generation=False, learning_to_test_ratio=learning_to_test_ratio)
192192

193193
all_results[idx].append(r)
194194

@@ -208,8 +208,53 @@ def run_experiments_multiple_times(test_models, num_times):
208208
with open('papni_rpni_eval_results.pickle', 'wb') as handle:
209209
pickle.dump(all_results, handle, protocol=pickle.HIGHEST_PROTOCOL)
210210

211+
def test_papni_based_on_sevpa_dataset():
212+
all_models = get_all_VPAs()
211213

212-
all_models = get_all_VPAs()
213214

214-
run_all_experiments_experiments(all_models, learning_to_test_ratio=0.5)
215-
# run_experiments_multiple_times(all_models, 20)
215+
for idx, gt in enumerate(all_models):
216+
sevpa_papni_mismatch, papni_error, sevpa_error = 0, 0, 0
217+
218+
input_al = gt.get_input_alphabet()
219+
220+
sevpa_dataset, sevpa_model = get_sequences_from_active_sevpa(gt)
221+
sevpa_dataset_set = set(sevpa_dataset)
222+
223+
papni_model = run_PAPNI(sevpa_dataset, input_al)
224+
225+
balanced_counter = 0
226+
not_in_learning = 0
227+
in_learning = 0
228+
229+
for seq, label in all_data[idx]:
230+
231+
if is_balanced(seq, input_al):
232+
balanced_counter += 1
233+
234+
if (seq, label) not in sevpa_dataset_set:
235+
not_in_learning += 1
236+
else:
237+
in_learning += 1
238+
239+
sevpa_model.reset_to_initial()
240+
sevpa_output = sevpa_model.execute_sequence(sevpa_model.initial_state, seq)
241+
242+
papni_model.reset_to_initial()
243+
papni_output = papni_model.execute_sequence(papni_model.initial_state, seq)
244+
245+
if sevpa_output != papni_output:
246+
sevpa_papni_mismatch += 1
247+
if papni_output[-1] != label:
248+
papni_error += 1
249+
if sevpa_output[-1] != label:
250+
sevpa_error += 1
251+
252+
print('--------------------------------------')
253+
print(f'Model Index {idx}; # well-matched {balanced_counter}, # unique tests {not_in_learning}')
254+
print(f'Papni Error {papni_error}')
255+
print(f'Sevpa Error {sevpa_error}')
256+
print(f'Mismatch {sevpa_papni_mismatch}')
257+
258+
assert in_learning + not_in_learning == balanced_counter
259+
260+
test_papni_based_on_sevpa_dataset()

0 commit comments

Comments
 (0)