Skip to content

Commit b23bfe8

Browse files
author
Edi Muškardin
authored
Merge pull request #84 from DES-Lab/gsm-dev
Support more data formats and auto-detect
2 parents fd4f9ae + 8e06325 commit b23bfe8

File tree

5 files changed

+106
-10
lines changed

5 files changed

+106
-10
lines changed

Examples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,12 +1184,12 @@ def passive_vpa_learning_on_all_benchmark_models():
11841184

11851185
def gsm_rpni():
11861186
from aalpy import load_automaton_from_file
1187-
from aalpy.utils.Sampling import get_io_traces, sample_with_length_limits
1187+
from aalpy.utils.Sampling import get_data_from_input_sequence, sample_with_length_limits
11881188
from aalpy.learning_algs.general_passive.GeneralizedStateMerging import run_GSM
11891189

11901190
automaton = load_automaton_from_file("DotModels/car_alarm.dot", "moore")
11911191
input_traces = sample_with_length_limits(automaton.get_input_alphabet(), 100, 20, 30)
1192-
traces = get_io_traces(automaton, input_traces)
1192+
traces = get_data_from_input_sequence(automaton, input_traces, "io_traces")
11931193

11941194
learned_model = run_GSM(traces, output_behavior="moore", transition_behavior="deterministic")
11951195
learned_model.visualize()

aalpy/learning_algs/general_passive/GeneralizedStateMerging.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Dict, Tuple, Callable, List, Optional
44

55
from aalpy.learning_algs.general_passive.GsmNode import GsmNode, OutputBehavior, TransitionBehavior, TransitionInfo, \
6-
OutputBehaviorRange, TransitionBehaviorRange, intersection_iterator, NodeOrders, unknown_output
6+
OutputBehaviorRange, TransitionBehaviorRange, intersection_iterator, NodeOrders, unknown_output, detect_data_format
77
from aalpy.learning_algs.general_passive.ScoreFunctionsGSM import ScoreCalculation, hoeffding_compatibility
88

99

@@ -93,13 +93,17 @@ def compute_local_compatibility(self, a: GsmNode, b: GsmNode):
9393

9494
# TODO: make more generic by adding the option to use a different algorithm than red blue
9595
# for selecting potential merge candidates. Maybe using inheritance with abstract `run`.
96-
def run(self, data, convert=True, instrumentation: Instrumentation=None, data_format="io_traces"):
96+
def run(self, data, convert=True, instrumentation: Instrumentation=None, data_format=None):
9797
if instrumentation is None:
9898
instrumentation = Instrumentation()
9999
instrumentation.reset(self)
100100

101+
if data_format is None:
102+
data_format = detect_data_format(data)
101103
if data_format == "labeled_sequences" and self.transition_behavior != "deterministic":
102104
raise ValueError("learning from labeled_sequences is not possible for nondeterministic systems")
105+
if data_format == "traces" and self.transition_behavior == "deterministic":
106+
print("learning deterministic systems from (output) traces only. this rarely makes sense. is `data_format` set correctly?")
103107
root = GsmNode.createPTA(data, self.output_behavior, data_format)
104108

105109
root = self.pta_preprocessing(root)
@@ -310,7 +314,7 @@ def run_GSM(data: list, *,
310314
depth_first=False,
311315
instrumentation=None,
312316
convert=True,
313-
data_format='io_traces',
317+
data_format=None,
314318
):
315319
"""
316320
Performs a state merging algorithm in the red-blue framework on provided data.

aalpy/learning_algs/general_passive/GsmNode.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
TransitionBehaviorRange = ["deterministic", "nondeterministic", "stochastic"]
2121

2222
DataFormat = str
23-
DataFormatRange = ["io_traces", "labeled_sequences", "tree"]
23+
DataFormatRange = ["io_traces", "labeled_sequences", "traces", "tree"]
2424

2525
IOPair = Tuple[Any, Any]
2626
IOTrace = Sequence[IOPair]
@@ -52,6 +52,52 @@ def union_iterator(a: Dict[Key, Val], b: Dict[Key, Val], default: Val = None) ->
5252
yield key, a_val, b_val
5353

5454

55+
# TODO reuse in RPNI
56+
def detect_data_format(data, check_consistency=False, guess=False):
57+
# The different data formats are
58+
# - "tree": a tree-shaped automaton provided as a GsmNode
59+
# - "io_traces": either
60+
# - Moore traces [[o, (i,o), (i,o), ...], ...]
61+
# - Mealy traces [[(i,o), (i,o), ...], ...]
62+
# - "labeled_sequences": [([i, i, ...], o), ...]
63+
# - "traces": [[o, o, ...], ...]
64+
65+
if isinstance(data, GsmNode):
66+
if not data.is_tree():
67+
raise ValueError("provided automaton is not a tree")
68+
return "tree"
69+
70+
accepted_types = (Tuple, List)
71+
72+
# mapping data formats to compatibility criteria
73+
check_dict = dict(
74+
io_traces=lambda obj: len(obj) <= 1 or all(isinstance(o, accepted_types) and len(o) == 2 for o in obj[1:]),
75+
labeled_sequences=lambda obj: len(obj) == 2 and isinstance(obj[0], accepted_types),
76+
)
77+
accept_dict = {k: True for k in check_dict}
78+
79+
if not isinstance(data, accepted_types):
80+
raise ValueError("wrong input format. expected tuple or list.")
81+
if len(data) == 0:
82+
return "io_traces"
83+
84+
accepted_formats = list(accept_dict.keys())
85+
for data_point in data:
86+
if not isinstance(data_point, accepted_types):
87+
raise ValueError("wrong input format. expected tuple or list.")
88+
for k, check in check_dict.items():
89+
accept_dict[k] &= check(data_point)
90+
accepted_formats = [k for k, v in accept_dict.items() if v]
91+
if len(accepted_formats) == 1 and not check_consistency:
92+
return accepted_formats[0]
93+
if len(accepted_formats) == 0:
94+
return "traces" # default to traces
95+
#raise ValueError("invalid or inconsistent data. no options left")
96+
if len(accepted_formats) != 1 and not guess:
97+
raise ValueError("ambiguous data format. data format needs to be specified explicitly.")
98+
return accepted_formats[0]
99+
100+
55101
# TODO maybe split this for maintainability (and perfomance?)
56102
class TransitionInfo:
57103
__slots__ = ["target", "count", "original_target", "original_count"]
@@ -379,6 +425,8 @@ def add_labeled_sequence(self, example: IOExample):
379425

380426
@staticmethod
381427
def createPTA(data, output_behavior, data_format=None) -> 'GsmNode':
428+
if data_format is None:
429+
data_format = detect_data_format(data)
382430
if data_format not in DataFormatRange:
383431
raise ValueError(f"invalid data format {data_format}. should be in {DataFormatRange}")
384432

@@ -388,12 +436,14 @@ def createPTA(data, output_behavior, data_format=None) -> 'GsmNode':
388436
if data_format == "labeled_sequences":
389437
for example in data:
390438
root_node.add_labeled_sequence(example)
391-
if data_format == "io_traces":
439+
if data_format == "io_traces" or data_format == "traces":
392440
if output_behavior == "moore":
393441
initial_output = data[0][0]
394442
root_node.prefix_access_pair = (None, initial_output)
395443
data = (d[1:] for d in data)
396444
for trace in data:
445+
if data_format == "traces":
446+
trace = (("step", t) for t in trace)
397447
root_node.add_trace(trace)
398448
return root_node
399449

aalpy/utils/HelperFunctions.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from itertools import product
44
from collections import defaultdict
55

6+
from aalpy import Mdp, MarkovChain, McState, MooreMachine, Dfa, DfaState
7+
68

79
def extend_set(list_to_extend: list, new_elements: list) -> list:
810
"""
@@ -409,9 +411,7 @@ def product_with_possible_empty_iterable(*iterables, repeat=1):
409411
return product(*non_empty_iterables, repeat=repeat)
410412

411413

412-
def dfa_from_moore(moore_model):
413-
from aalpy.automata import Dfa, DfaState
414-
414+
def dfa_from_moore(moore_model: MooreMachine) -> Dfa:
415415
dfa_state_map = dict()
416416
# define states
417417
for moore_state in moore_model.states:
@@ -430,3 +430,20 @@ def dfa_from_moore(moore_model):
430430

431431
initial_state = dfa_state_map[moore_model.initial_state.state_id]
432432
return Dfa(initial_state, list(dfa_state_map.values()))
433+
434+
def mc_from_mdp(mdp: Mdp, input_symbol=None) -> MarkovChain:
435+
alphabet = mdp.get_input_alphabet()
436+
if len(alphabet) != 1 and input_symbol is None:
437+
raise ValueError('Cannot convert MDP with several inputs to Markov chain.')
438+
input_symbol = input_symbol or alphabet[0]
439+
440+
state_map = {state.state_id: McState(state.state_id, state.output) for state in mdp.states}
441+
for state in mdp.states:
442+
mdp_transitions = state.transitions.get(input_symbol)
443+
if mdp_transitions is None:
444+
continue
445+
mc_transitions = [(state_map[mdp_target.state_id], prob) for mdp_target, prob in mdp_transitions]
446+
state_map[state.state_id].transitions = mc_transitions
447+
448+
initial_state = state_map[mdp.initial_state.state_id]
449+
return MarkovChain(initial_state, list(state_map.values()))

aalpy/utils/Sampling.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,31 @@ def get_io_traces(automaton: Automaton, input_traces: list) -> list:
1818
return traces
1919

2020

21+
def get_labeled_sequences(automaton: Automaton, input_traces: list) -> list:
22+
moore_automata = (MooreMachine, Dfa, NDMooreMachine, Mdp, MarkovChain)
23+
is_moore = isinstance(automaton, moore_automata)
24+
25+
data = []
26+
for input_trace in input_traces:
27+
if len(input_trace) == 0:
28+
if not is_moore:
29+
raise ValueError("tried to get label of empty sequence for Mealy automaton.")
30+
output = automaton.initial_state.output
31+
else:
32+
output = automaton.execute_sequence(automaton.initial_state, input_trace)[-1]
33+
data.append((input_trace, output))
34+
return data
35+
36+
37+
def get_data_from_input_sequence(automaton: Automaton, input_sequence: list, data_format: str = "io_sequences"):
38+
if data_format == "io_sequences":
39+
return get_io_traces(automaton, input_sequence)
40+
elif data_format == "labeled_sequences":
41+
return get_labeled_sequences(automaton, input_sequence)
42+
else:
43+
raise ValueError(f"invalid data_format {data_format}. must be 'io_sequences' or 'labeled_sequences'")
44+
45+
2146
def support_automaton_arg(require_transform):
2247
def decorator(f):
2348
@wraps(f)

0 commit comments

Comments
 (0)