Skip to content

Commit 7fcd01f

Browse files
author
Edi Muskardin
committed
update rpni vs papni comparison script
1 parent bb40ce9 commit 7fcd01f

File tree

4 files changed

+275
-153
lines changed

4 files changed

+275
-153
lines changed
Lines changed: 109 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from aalpy import run_RPNI, run_PAPNI, load_automaton_from_file
1+
from aalpy import run_RPNI, run_PAPNI, AutomatonSUL
22
from aalpy.utils import convert_i_o_traces_for_RPNI, generate_input_output_data_from_vpa
33
from aalpy.utils.BenchmarkVpaModels import get_all_VPAs
44

@@ -17,29 +17,21 @@ def calculate_precision_recall_f1(true_positives, false_positives, false_negativ
1717
return precision, recall, f1
1818

1919

20-
def compare_rpni_and_papni(original_model, rpni_model, papni_model, num_sequances, min_seq_len, max_seq_len):
21-
test_data = generate_input_output_data_from_vpa(original_model, num_sequances, min_seq_len, max_seq_len)
22-
test_data = convert_i_o_traces_for_RPNI(test_data)
23-
24-
def calculate_f1_score(precision, recall):
25-
if precision + recall == 0:
26-
return 0
27-
return 2 * (precision * recall) / (precision + recall)
28-
20+
def compare_rpni_and_papni(test_data, rpni_model, papni_model):
2921
def evaluate_model(learned_model, test_data):
3022
true_positives = 0
3123
false_positives = 0
3224
false_negatives = 0
3325

34-
for input_seq, out in test_data:
26+
for input_seq, correct_output in test_data:
3527
learned_model.reset_to_initial()
3628
learned_output = learned_model.execute_sequence(learned_model.initial_state, input_seq)[-1]
3729

38-
if learned_output and out:
30+
if learned_output and correct_output:
3931
true_positives += 1
40-
elif learned_output and not out:
32+
elif learned_output and not correct_output:
4133
false_positives += 1
42-
elif not learned_output and out:
34+
elif not learned_output and correct_output:
4335
false_negatives += 1
4436

4537
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
@@ -51,26 +43,114 @@ def evaluate_model(learned_model, test_data):
5143
rpni_error = evaluate_model(rpni_model, test_data)
5244
papni_error = evaluate_model(papni_model, test_data)
5345

54-
print(f'-----------------------------------------------------------------')
55-
print(f'RPNI size {rpni_model.size} vs {papni_model.size} PAPNI size')
56-
print(f'RPNI precision, recall, f1: {rpni_error}')
57-
print(f'PAPNI precision, recall, f1: {papni_error}')
46+
# print(f'RPNI size {rpni_model.size} vs {papni_model.size} PAPNI size')
47+
# print(f'RPNI precision, recall, f1: {rpni_error}')
48+
# print(f'PAPNI precision, recall, f1: {papni_error}')
49+
50+
return [rpni_model.size, papni_model.size, rpni_error, papni_error]
51+
52+
53+
def get_sequances_from_active_sevpa(model):
    """Collect the traces executed while KV actively learns `model`.

    The model is wrapped in an ``AutomatonSUL`` and then in a recording SUL
    that stores the (input, output) pairs of every pre()/post() cycle. After
    ``run_KV`` finishes, the recorded traces are passed through
    ``convert_i_o_traces_for_RPNI`` and returned; the learned hypothesis
    itself is discarded.

    Args:
        model: automaton exposing ``get_input_alphabet()`` whose result has
            ``internal_alphabet``, ``call_alphabet`` and ``return_alphabet``.

    Returns:
        Whatever ``convert_i_o_traces_for_RPNI`` produces for the recorded
        traces.
    """
    from aalpy import SUL, run_KV, RandomWordEqOracle, SevpaAlphabet

    class CustomSUL(SUL):
        """SUL decorator that logs every step of every query it forwards."""

        def __init__(self, automatonSUL):
            super().__init__()
            self.sul = automatonSUL
            # One entry per completed pre()/post() cycle.
            self.recorded_traces = []

        def pre(self):
            # A new query begins: start an empty trace, then reset the wrapped SUL.
            self.current_trace = []
            self.sul.pre()

        def post(self):
            # The query is over: archive its trace before the wrapped SUL cleans up.
            self.recorded_traces.append(self.current_trace)
            self.sul.post()

        def step(self, letter):
            output = self.sul.step(letter)
            if letter is None:
                # Steps with a None input are forwarded but not recorded.
                return output
            self.current_trace.append((letter, output))
            return output

    input_alphabet = model.get_input_alphabet()
    sevpa_alphabet = SevpaAlphabet(
        input_alphabet.internal_alphabet,
        input_alphabet.call_alphabet,
        input_alphabet.return_alphabet,
    )

    recording_sul = CustomSUL(AutomatonSUL(model))
    random_word_oracle = RandomWordEqOracle(
        sevpa_alphabet.get_merged_alphabet(),
        recording_sul,
        num_walks=50000,
        min_walk_len=6,
        max_walk_len=18,
        reset_after_cex=False,
    )

    # Only the queries issued during learning matter here, not the hypothesis.
    run_KV(sevpa_alphabet, recording_sul, random_word_oracle, automaton_type='vpa', print_level=3)

    return convert_i_o_traces_for_RPNI(recording_sul.recorded_traces)
86+
87+
88+
def split_data_to_learning_and_testing(data, learning_to_test_ratio=0.5):
    """Split labeled sequences into a learning set and a test set.

    Samples are visited from shortest to longest sequence. For each label
    class (truthy = positive, falsy = negative) the first
    ``class_size * learning_to_test_ratio`` samples go to the learning set
    (a fractional quota is effectively rounded up); all remaining samples go
    to the test set. The shortest sequences of each class therefore end up in
    the learning set.

    Args:
        data: sequence of ``(sequence, label)`` pairs; ``sequence`` must
            support ``len()``.
        learning_to_test_ratio: fraction of each class assigned to the
            learning set. ``0`` sends everything to the test set, ``1``
            sends everything to the learning set.

    Returns:
        Tuple ``(learning_sequances, test_sequances)`` of lists of
        ``(sequence, label)`` pairs. The input ``data`` is not modified.
    """
    total_number_positive = sum(1 for _, label in data if label)
    total_number_negative = len(data) - total_number_positive

    num_learning_positive_seq = total_number_positive * learning_to_test_ratio
    num_learning_negative_seq = total_number_negative * learning_to_test_ratio

    # sorted() returns a new list, so its result must be kept. The sort is
    # stable, so samples of equal length keep their original relative order.
    ordered_data = sorted(data, key=lambda x: len(x[0]))

    learning_sequances, test_sequances = [], []

    l_pos, l_neg = 0, 0
    for seq, label in ordered_data:
        # Strict '<' keeps each class within its quota; '<=' would admit one
        # extra sample per class (and a sample even when the ratio is 0).
        if label and l_pos < num_learning_positive_seq:
            learning_sequances.append((seq, label))
            l_pos += 1
        elif not label and l_neg < num_learning_negative_seq:
            learning_sequances.append((seq, label))
            l_neg += 1
        else:
            test_sequances.append((seq, label))

    return learning_sequances, test_sequances
111+
112+
113+
def run_experiment(ground_truth_model,
                   num_of_learning_seq,
                   max_learning_seq_len,
                   random_data_generation=True):
    """Learn RPNI and PAPNI models from one ground truth and compare them.

    Data is either sampled with ``generate_input_output_data_from_vpa`` or
    harvested from an active KV run via ``get_sequances_from_active_sevpa``.
    It is then split in half into learning and test sets, both learners are
    run on the learning set, and the models are scored on the test set.

    Args:
        ground_truth_model: model the data is generated from.
        num_of_learning_seq: number of sequences requested from the random
            generator (unused when ``random_data_generation`` is falsy).
        max_learning_seq_len: maximum length of randomly generated sequences
            (unused when ``random_data_generation`` is falsy).
        random_data_generation: pick random sampling (truthy) or traces from
            active learning (falsy).

    Returns:
        The list returned by ``compare_rpni_and_papni`` extended with two
        ``(positive_count, negative_count)`` tuples: first for the learning
        set, then for the test set.
    """

    def count_by_label(samples):
        # (number of truthy-labeled samples, number of the remaining samples)
        positives = len([sample for sample in samples if sample[1]])
        return positives, len(samples) - positives

    # NOTE(review): the active-learning branch returns traces that went through
    # convert_i_o_traces_for_RPNI, while the random branch uses the generator
    # output as-is; confirm the generator already yields (sequence, label) pairs.
    if not random_data_generation:
        data = get_sequances_from_active_sevpa(ground_truth_model)
    else:
        data = generate_input_output_data_from_vpa(
            ground_truth_model,
            num_sequances=num_of_learning_seq,
            max_seq_len=max_learning_seq_len,
        )

    vpa_alphabet = ground_truth_model.get_input_alphabet()

    learning_data, test_data = split_data_to_learning_and_testing(data, learning_to_test_ratio=0.5)

    learning_set_size = count_by_label(learning_data)
    test_set_size = count_by_label(test_data)

    rpni_model = run_RPNI(learning_data, 'dfa', print_info=False, input_completeness='sink_state')
    papni_model = run_PAPNI(learning_data, vpa_alphabet, print_info=False)

    comparison_results = compare_rpni_and_papni(test_data, rpni_model, papni_model)

    # Append the data-set statistics after the comparison entries.
    return comparison_results + [learning_set_size, test_set_size]
59142

60-
arithmetics_model = load_automaton_from_file('../DotModels/arithmetics.dot', 'vpa')
61143

62-
# 15 test benchmarks
63-
test_models = [arithmetics_model]
64-
test_models.extend(get_all_VPAs())
144+
def run_all_experiments_experiments(test_models):
145+
for idx, gt in enumerate(test_models):
146+
results = run_experiment(gt, num_of_learning_seq=10000, max_learning_seq_len=50, random_data_generation=True)
65147

66-
for ground_truth in test_models:
67-
vpa_alphabet = ground_truth.get_input_alphabet()
148+
res_str = f'GT {idx + 1}:\t Learning ({results[-2][0]}/{results[-2][1]}),\t Test ({results[-1][0]}/{results[-1][1]}),\t'
149+
res_str += f'RPNI: size: {results[0]}, prec/rec/F1: {results[2]}, \t PAPNI size: {results[1]}, prec/rec/F1: {results[3]}'
68150

69-
data = generate_input_output_data_from_vpa(ground_truth, num_sequances=2000, min_seq_len=1, max_seq_len=12)
70-
data = convert_i_o_traces_for_RPNI(data)
151+
print(res_str)
71152

72-
rpni_model = run_RPNI(data, 'dfa', print_info=True, input_completeness='sink_state')
73153

74-
papni_model = run_PAPNI(data, vpa_alphabet, print_info=True)
154+
all_models = get_all_VPAs()
75155

76-
compare_rpni_and_papni(ground_truth, rpni_model, papni_model, num_sequances=100, min_seq_len=20, max_seq_len=40)
156+
run_all_experiments_experiments(all_models)

Benchmarking/vpa_benchmarking/benchmark_vpa.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pickle
66

7-
from aalpy.SULs.AutomataSUL import SevpaSUL, VpaSUL, DfaSUL
7+
from aalpy.SULs.AutomataSUL import SevpaSUL, DfaSUL
88
from aalpy.automata import SevpaAlphabet
99
from aalpy.learning_algs import run_KV
1010
from aalpy.oracles import RandomWordEqOracle
@@ -196,8 +196,8 @@ def benchmark_vpa_dfa():
196196
label_data = []
197197

198198
for i, vpa in enumerate(
199-
[vpa_for_L1(), vpa_for_L2(), vpa_for_L3(), vpa_for_L4(), vpa_for_L5(), vpa_for_L7(), vpa_for_L8(),
200-
vpa_for_L9(), vpa_for_L10(), vpa_for_L11(), vpa_for_L12(), vpa_for_L13(), vpa_for_L14(), vpa_for_L15()]):
199+
[vpa_L1(), vpa_L2(), vpa_for_L3(), vpa_L3(), vpa_for_L5(), vpa_L4(), vpa_L6(),
200+
vpa_L8(), vpa_for_L10(), vpa_for_L11(), vpa_L9(), vpa_L10(), vpa_L11(), vpa_L12()]):
201201
print(f'VPA {i + 1 if i < 6 else i + 2}')
202202
label_data.append(f'VPA {i + 1 if i < 6 else i + 2}')
203203

0 commit comments

Comments
 (0)