2020TransitionBehaviorRange = ["deterministic" , "nondeterministic" , "stochastic" ]
2121
2222DataFormat = str
23- DataFormatRange = ["io_traces" , "labeled_sequences" , "tree" ]
23+ DataFormatRange = ["io_traces" , "labeled_sequences" , "traces" , " tree" ]
2424
2525IOPair = Tuple [Any , Any ]
2626IOTrace = Sequence [IOPair ]
@@ -52,6 +52,52 @@ def union_iterator(a: Dict[Key, Val], b: Dict[Key, Val], default: Val = None) ->
5252 yield key , a_val , b_val
5353
5454
55+ # TODO reuse in RPNI
56+ def detect_data_format (data , check_consistency = False , guess = False ):
57+ # The different data formats are
58+ # - "tree": a tree-shaped automaton provided as a GsmNode
59+ # - "io_traces": either
60+ # - Moore traces [[o, (i,o), (i,o), ...], ...]
61+ # - Mealy traces [[(i,o), (i,o), ...], ...]
62+ # - "labeled_sequences": [([i, i, ...], o), ...]
63+ # - "traces": [[o, o, ...], ...]
64+
65+ if isinstance (data , GsmNode ):
66+ if not data .is_tree ():
67+ raise ValueError ("provided automaton is not a tree" )
68+ return "tree"
69+
70+ accepted_types = (Tuple , List )
71+
72+ # mapping data formats to compatibility criteria
73+ check_dict = dict (
74+ io_traces = lambda obj : len (obj ) <= 1 or all (isinstance (o , accepted_types ) and len (o ) == 2 for o in obj [1 :]),
75+ labeled_sequences = lambda obj : len (obj ) == 2 and isinstance (obj [0 ], accepted_types ),
76+ )
77+ accept_dict = {k : True for k in check_dict }
78+
79+ if not isinstance (data , accepted_types ):
80+ raise ValueError ("wrong input format. expected tuple or list." )
81+ if len (data ) == 0 :
82+ return "io_traces"
83+
84+ accepted_formats = list (accept_dict .keys ())
85+ for data_point in data :
86+ if not isinstance (data_point , accepted_types ):
87+ raise ValueError ("wrong input format. expected tuple or list." )
88+ for k , check in check_dict .items ():
89+ accept_dict [k ] &= check (data_point )
90+ accepted_formats = [k for k , v in accept_dict .items () if v ]
91+ if len (accepted_formats ) == 1 and not check_consistency :
92+ return accepted_formats [0 ]
93+ if len (accepted_formats ) == 0 :
94+ return "traces" # default to traces
95+ #raise ValueError("invalid or inconsistent data. no options left")
96+ if len (accepted_formats ) != 1 and not guess :
97+ raise ValueError ("ambiguous data format. data format needs to be specified explicitly." )
98+ return accepted_formats [0 ]
99+
100+
55101# TODO maybe split this for maintainability (and perfomance?)
56102class TransitionInfo :
57103 __slots__ = ["target" , "count" , "original_target" , "original_count" ]
@@ -379,6 +425,8 @@ def add_labeled_sequence(self, example: IOExample):
379425
380426 @staticmethod
381427 def createPTA (data , output_behavior , data_format = None ) -> 'GsmNode' :
428+ if data_format is None :
429+ data_format = detect_data_format (data )
382430 if data_format not in DataFormatRange :
383431 raise ValueError (f"invalid data format { data_format } . should be in { DataFormatRange } " )
384432
@@ -388,12 +436,14 @@ def createPTA(data, output_behavior, data_format=None) -> 'GsmNode':
388436 if data_format == "labeled_sequences" :
389437 for example in data :
390438 root_node .add_labeled_sequence (example )
391- if data_format == "io_traces" :
439+ if data_format == "io_traces" or data_format == "traces" :
392440 if output_behavior == "moore" :
393441 initial_output = data [0 ][0 ]
394442 root_node .prefix_access_pair = (None , initial_output )
395443 data = (d [1 :] for d in data )
396444 for trace in data :
445+ if data_format == "traces" :
446+ trace = (("step" , t ) for t in trace )
397447 root_node .add_trace (trace )
398448 return root_node
399449
0 commit comments