@@ -399,7 +399,7 @@ fun bdd_to_tables_iterative bdd_term groupings_term =
399399
400400
401401(* MTBDD to Rules:simple outout *)
402- fun mtbdd_to_rules1 bdd_term =
402+ fun mtbdd_to_rules_all_paths bdd_term =
403403let
404404 open pairSyntax listSyntax stringSyntax numSyntax;
405405
549549
550550
551551(* MTBDD to Rules with or combinations grouped by action *)
552- fun mtbdd_to_rules2 bdd_term =
552+ fun mtbdd_to_rules_grouped_by_action_simple bdd_term =
553553let
554554 open pairSyntax listSyntax stringSyntax numSyntax;
555555
734734end
735735
736736
737+ (* best policy output from BDD *)
738+ fun mtbdd_to_rules bdd_term =
739+ let
740+ open pairSyntax listSyntax stringSyntax numSyntax;
741+
742+ fun num_of_term t = Arbnum.toInt (dest_numeral t);
743+
744+ val (root_term, rest) = dest_pair bdd_term
745+ val (edges_term, labels_term) = dest_pair rest
746+
747+ val edges_list = fst (dest_list edges_term)
748+ val edges = map (fn edge =>
749+ let
750+ val (parent, children) = dest_pair edge
751+ val (left, right) = dest_pair children
752+ in
753+ (num_of_term parent, num_of_term left, num_of_term right)
754+ end ) edges_list
755+
756+ val labels_list = fst (dest_list labels_term)
757+ val labels = map (fn label =>
758+ let val (id, data) = dest_pair label
759+ in (num_of_term id, data) end ) labels_list
760+
761+ fun get_var_name node_data =
762+ let
763+ val (const, args) = strip_comb node_data
764+ val const_name = #Name (dest_thy_const const)
765+ in
766+ if const_name = " non_termn" then
767+ let
768+ val arg = hd args
769+ val (some_part, _) = dest_pair arg
770+ val (_, some_args) = strip_comb some_part
771+ val str_term = hd some_args
772+ in
773+ stringSyntax.fromHOLstring str_term
774+ end
775+ else
776+ raise Fail (" Not a non-terminal node: " ^ const_name)
777+ end
778+
779+ fun get_action node_data =
780+ let
781+ val (const, args) = strip_comb node_data
782+ val const_name = #Name (dest_thy_const const)
783+ in
784+ if const_name = " termn" then
785+ let
786+ val first_arg = hd args
787+ val (actual_action, _) = dest_pair first_arg
788+ handle _ => (first_arg, ``()``)
789+ in
790+ actual_action
791+ end
792+ else
793+ raise Fail (" Not a terminal node: " ^ const_name)
794+ end
795+
796+ fun node_type node_data =
797+ let
798+ val (const, _) = strip_comb node_data
799+ val const_name = #Name (dest_thy_const const)
800+ in
801+ const_name
802+ end
803+
804+ fun is_terminal node_data = node_type node_data = " termn"
805+
806+ val root_id = num_of_term root_term
807+
808+ (* Find all terminals in left-to-right order, remove duplicates *)
809+ fun dfs_collect node_id =
810+ let
811+ val node_data = case List.find (fn (id, _) => id = node_id) labels of
812+ SOME (_, data) => data
813+ | NONE => raise Fail (" Node not found: " ^ Int.toString node_id)
814+ in
815+ if is_terminal node_data then
816+ [node_id]
817+ else
818+ let
819+ val (left_child, right_child) =
820+ case List.find (fn (src, l, r) => src = node_id) edges of
821+ SOME (_, l, r) => (l, r)
822+ | NONE => raise Fail (" No edges from node " ^ Int.toString node_id)
823+ in
824+ dfs_collect left_child @ dfs_collect right_child
825+ end
826+ end
827+
828+ val terminal_ids = dfs_collect root_id
829+ val unique_terminals =
830+ let
831+ val seen = ref []
832+ fun keep_unique [] = []
833+ | keep_unique (x::xs) =
834+ if List.exists (fn y => y = x) (!seen) then keep_unique xs
835+ else (seen := x :: !seen; x :: keep_unique xs)
836+ in
837+ keep_unique terminal_ids
838+ end
839+
840+ (* Get left child of a node *)
841+ fun get_left_child node_id =
842+ case List.find (fn (src, l, r) => src = node_id) edges of
843+ SOME (_, left, _) => left
844+ | NONE => raise Fail (" No edges from node " ^ Int.toString node_id)
845+
846+ (* Get all parents of a node *)
847+ fun get_parents node_id =
848+ List.map (fn (src, _, _) => src)
849+ (List.filter (fn (_, left, right) => left = node_id orelse right = node_id) edges)
850+
851+ (* Check if a node can be a rule starting point *)
852+ fun can_be_rule_start node_id =
853+ if node_id = root_id then
854+ true (* Root can always start a rule *)
855+ else
856+ let
857+ (* Condition 1: All parents reach this node via RIGHT branch *)
858+ val incoming_edges = List.filter (fn (_, left, right) =>
859+ left = node_id orelse right = node_id) edges
860+
861+ val all_via_right = List.all (fn (_, left, right) => right = node_id) incoming_edges
862+
863+ (* Condition 2: Node's left child has only this node as parent (via left branch) *)
864+ val left_child = get_left_child node_id
865+ val left_child_parents = get_parents left_child
866+
867+ (* Check if left child has exactly one parent AND it's this node via left branch *)
868+ val left_child_has_single_parent =
869+ length left_child_parents = 1 andalso
870+ hd left_child_parents = node_id andalso
871+ List.exists (fn (src, left, _) => src = node_id andalso left = left_child) edges
872+ in
873+ all_via_right andalso left_child_has_single_parent
874+ end
875+
876+ (* Find starting nodes for rules - nodes that can start rules *)
877+ fun find_rule_starting_nodes target_terminal =
878+ let
879+ (* Find all nodes that point to target_terminal via LEFT branch *)
880+ val left_parents = List.map (fn (src, left, _) => src)
881+ (List.filter (fn (_, left, _) => left = target_terminal) edges)
882+
883+ (* For each left parent, find the closest ancestor that can start a rule *)
884+ fun find_starting_node_for_parent parent_id =
885+ let
886+ fun trace_back node_id =
887+ if can_be_rule_start node_id then
888+ node_id (* Found a valid starting point *)
889+ else if node_id = root_id then
890+ root_id (* Reached root, use it as starting point *)
891+ else
892+ (* Find parent and continue *)
893+ let
894+ val parents = get_parents node_id
895+ in
896+ case parents of
897+ [parent] => trace_back parent
898+ | _ => raise Fail (" Multiple or no parents for node " ^ Int.toString node_id)
899+ end
900+ in
901+ trace_back parent_id
902+ end
903+
904+ (* Also check if terminal is directly at root *)
905+ val root_start =
906+ if target_terminal = root_id then
907+ [root_id]
908+ else if List.exists (fn (src, left, _) => src = root_id andalso left = target_terminal) edges then
909+ [root_id]
910+ else []
911+ in
912+ (* Remove duplicates from starting nodes *)
913+ let
914+ val all_starts = root_start @ (map find_starting_node_for_parent left_parents)
915+ fun remove_dups [] = []
916+ | remove_dups (x::xs) =
917+ if List.exists (fn y => y = x) xs then remove_dups xs
918+ else x::remove_dups xs
919+ in
920+ remove_dups all_starts
921+ end
922+ end
923+
924+ (* Find all paths from a starting node to target terminal *)
925+ fun find_paths_from_start start_id target_id =
926+ let
927+ fun dfs current_id current_path =
928+ if current_id = target_id then
929+ [List.rev current_path] (* Found target *)
930+ else
931+ let
932+ val node_data = case List.find (fn (id, _) => id = current_id) labels of
933+ SOME (_, data) => data
934+ | NONE => raise Fail (" Node not found" )
935+ in
936+ if is_terminal node_data then
937+ [] (* Different terminal *)
938+ else
939+ let
940+ val (left_child, right_child) =
941+ case List.find (fn (src, l, r) => src = current_id) edges of
942+ SOME (_, l, r) => (l, r)
943+ | NONE => raise Fail (" No edges from node" )
944+
945+ val var_name = get_var_name node_data
946+
947+ (* Try left branch *)
948+ val left_paths = dfs left_child ((current_id, var_name, true )::current_path)
949+
950+ (* Try right branch *)
951+ val right_paths = dfs right_child ((current_id, var_name, false )::current_path)
952+ in
953+ left_paths @ right_paths
954+ end
955+ end
956+ in
957+ dfs start_id []
958+ end
959+
960+ (* Build rule for a terminal (except last one) *)
961+ fun build_rule_for_terminal term_id =
962+ let
963+ val term_data = case List.find (fn (id, _) => id = term_id) labels of
964+ SOME (_, data) => data
965+ | NONE => raise Fail (" Terminal not found" )
966+ val action = get_action term_data
967+
968+ (* Find starting nodes for rules *)
969+ val starting_nodes = find_rule_starting_nodes term_id
970+
971+ (* For each starting node, find paths to terminal *)
972+ val all_paths =
973+ List.concat (map (fn start_id => find_paths_from_start start_id term_id) starting_nodes)
974+
975+ (* Convert a path to predicate: AND of variables where we took LEFT branch *)
976+ fun path_to_predicate path =
977+ let
978+ val positive_vars = List.map (fn (_, var_name, _) => var_name)
979+ (List.filter (fn (_, _, decision) => decision) path)
980+
981+ fun build_and [] = ``(True : pred)``
982+ | build_and [var] = ``(Var ^(stringSyntax.fromMLstring var)) : pred``
983+ | build_and (var::vars) =
984+ let
985+ val first_pred = ``(Var ^(stringSyntax.fromMLstring var)) : pred``
986+ val rest_pred = build_and vars
987+ in
988+ if aconv rest_pred ``(True : pred)`` then
989+ first_pred
990+ else
991+ ``(And ^first_pred ^rest_pred) : pred``
992+ end
993+ in
994+ build_and positive_vars
995+ end
996+
997+ (* Convert all paths to predicates *)
998+ val predicates = map path_to_predicate all_paths
999+
1000+ (* Remove duplicate predicates *)
1001+ val unique_predicates =
1002+ let
1003+ fun remove_dups [] = []
1004+ | remove_dups (p::ps) =
1005+ if List.exists (fn q => aconv p q) ps then remove_dups ps
1006+ else p::remove_dups ps
1007+ in
1008+ remove_dups predicates
1009+ end
1010+
1011+ (* Combine unique predicates with OR *)
1012+ val final_predicate =
1013+ case unique_predicates of
1014+ [] => ``(True : pred)``
1015+ | [p] => p
1016+ | p::ps => List.foldl (fn (pred, acc) => ``(Or ^acc ^pred) : pred``) p ps
1017+ in
1018+ (final_predicate, action)
1019+ end
1020+
1021+ (* Build all rules - last terminal gets True *)
1022+ fun build_rules [] = []
1023+ | build_rules terminals =
1024+ let
1025+ val num_terms = length terminals
1026+ fun build idx remaining =
1027+ case remaining of
1028+ [] => []
1029+ | [term_id] => (* Last terminal gets True *)
1030+ let
1031+ val term_data = case List.find (fn (id, _) => id = term_id) labels of
1032+ SOME (_, data) => data
1033+ | NONE => raise Fail (" Terminal not found" )
1034+ val action = get_action term_data
1035+ in
1036+ [(``(True : pred)``, action)]
1037+ end
1038+ | term_id::rest =>
1039+ let
1040+ val rule = build_rule_for_terminal term_id
1041+ in
1042+ rule :: build (idx+1 ) rest
1043+ end
1044+ in
1045+ build 0 terminals
1046+ end
1047+
1048+ val rules = build_rules unique_terminals
1049+
1050+ val rule_terms =
1051+ if null rules then
1052+ listSyntax.mk_list ([], ``:pred # action``)
1053+ else
1054+ let
1055+ val rule_pairs = map (fn (pred, act) => mk_pair (pred, act)) rules
1056+ val pair_type = type_of (hd rule_pairs)
1057+ in
1058+ listSyntax.mk_list (rule_pairs, pair_type)
1059+ end
1060+
1061+ in
1062+ rule_terms
1063+ end
1064+
1065+
1066+
7371067end
7381068
7391069(*
0 commit comments