33from random import shuffle
44
55from aalpy import run_RPNI , run_PAPNI , AutomatonSUL
6- from aalpy .utils import convert_i_o_traces_for_RPNI , generate_input_output_data_from_vpa
6+ from aalpy .utils import convert_i_o_traces_for_RPNI , generate_input_output_data_from_vpa , is_balanced
77from aalpy .utils .BenchmarkVpaModels import get_all_VPAs
88from statistics import mean , stdev
99
@@ -59,7 +59,7 @@ def evaluate_model(learned_model, test_data):
5959 return [rpni_model .size , papni_model .size , rpni_error , papni_error ]
6060
6161
62- def get_sequences_from_active_sevpa (model ):
62+ def get_sequences_from_active_sevpa (model , verbose = False ):
6363 from aalpy import SUL , run_KV , RandomWordEqOracle , SevpaAlphabet
6464
6565 class CustomSUL (SUL ):
@@ -89,9 +89,9 @@ def step(self, letter):
8989 eq_oracle = RandomWordEqOracle (alphabet .get_merged_alphabet (), sul , num_walks = 50000 , min_walk_len = 6 ,
9090 max_walk_len = 30 , reset_after_cex = True )
9191 # eq_oracle = BreadthFirstExplorationEqOracle(vpa_alphabet.get_merged_alphabet(), sul, 7)
92- _ = run_KV (alphabet , sul , eq_oracle , automaton_type = 'vpa' , print_level = 3 )
92+ lm = run_KV (alphabet , sul , eq_oracle , automaton_type = 'vpa' , print_level = 3 if verbose else 0 )
9393
94- return convert_i_o_traces_for_RPNI (sul .sequences )
94+ return convert_i_o_traces_for_RPNI (sul .sequences ), lm
9595
9696
9797def split_data_to_learning_and_testing (data , learning_to_test_ratio = 0.5 ):
@@ -183,12 +183,12 @@ def run_all_experiments_experiments(test_models, learning_to_test_ratio):
183183 print (res_str )
184184
185185
186- def run_experiments_multiple_times (test_models , num_times ):
186+ def run_experiments_multiple_times (test_models , num_times , learning_to_test_ratio = 0.5 ):
187187 all_results = defaultdict (list )
188188 for idx , gt in enumerate (test_models ):
189189 for _ in range (num_times ):
190190 r = run_experiment (idx , gt , num_of_learning_seq = 10000 , max_learning_seq_len = 50 ,
191- random_data_generation = False , learning_to_test_ratio = 0.5 )
191+ random_data_generation = False , learning_to_test_ratio = learning_to_test_ratio )
192192
193193 all_results [idx ].append (r )
194194
@@ -208,8 +208,53 @@ def run_experiments_multiple_times(test_models, num_times):
208208 with open ('papni_rpni_eval_results.pickle' , 'wb' ) as handle :
209209 pickle .dump (all_results , handle , protocol = pickle .HIGHEST_PROTOCOL )
210210
211+ def test_papni_based_on_sevpa_dataset ():
212+ all_models = get_all_VPAs ()
211213
212- all_models = get_all_VPAs ()
213214
214- run_all_experiments_experiments (all_models , learning_to_test_ratio = 0.5 )
215- # run_experiments_multiple_times(all_models, 20)
215+ for idx , gt in enumerate (all_models ):
216+ sevpa_papni_mismatch , papni_error , sevpa_error = 0 , 0 , 0
217+
218+ input_al = gt .get_input_alphabet ()
219+
220+ sevpa_dataset , sevpa_model = get_sequences_from_active_sevpa (gt )
221+ sevpa_dataset_set = set (sevpa_dataset )
222+
223+ papni_model = run_PAPNI (sevpa_dataset , input_al )
224+
225+ balanced_counter = 0
226+ not_in_learning = 0
227+ in_learning = 0
228+
229+ for seq , label in all_data [idx ]:
230+
231+ if is_balanced (seq , input_al ):
232+ balanced_counter += 1
233+
234+ if (seq , label ) not in sevpa_dataset_set :
235+ not_in_learning += 1
236+ else :
237+ in_learning += 1
238+
239+ sevpa_model .reset_to_initial ()
240+ sevpa_output = sevpa_model .execute_sequence (sevpa_model .initial_state , seq )
241+
242+ papni_model .reset_to_initial ()
243+ papni_output = papni_model .execute_sequence (papni_model .initial_state , seq )
244+
245+ if sevpa_output != papni_output :
246+ sevpa_papni_mismatch += 1
247+ if papni_output [- 1 ] != label :
248+ papni_error += 1
249+ if sevpa_output [- 1 ] != label :
250+ sevpa_error += 1
251+
252+ print ('--------------------------------------' )
253+ print (f'Model Index { idx } ; # well-matched { balanced_counter } , # unique tests { not_in_learning } ' )
254+ print (f'Papni Error { papni_error } ' )
255+ print (f'Sevpa Error { sevpa_error } ' )
256+ print (f'Mismatch { sevpa_papni_mismatch } ' )
257+
258+ assert in_learning + not_in_learning == balanced_counter
259+
260+ test_papni_based_on_sevpa_dataset ()
0 commit comments