Skip to content

Commit 47c0d2b

Browse files
added best version of BDD to policies in SML
1 parent c88ec3b commit 47c0d2b

File tree

2 files changed

+336
-5
lines changed

2 files changed

+336
-5
lines changed

hol/policy_to_table/bdd_utilsLib.sig

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ sig
55
val make_bv :int -> int -> term
66
val pairBDDs : term * term -> term
77
val bdd_to_tables_iterative : term -> term -> term
8-
val mtbdd_to_rules1 : term -> term
9-
val mtbdd_to_rules2 : term -> term
10-
8+
val mtbdd_to_rules_all_paths : term -> term
9+
val mtbdd_to_rules_grouped_by_action_simple : term -> term
10+
val mtbdd_to_rules : term -> term
11+
1112
end

hol/policy_to_table/bdd_utilsLib.sml

Lines changed: 332 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ fun bdd_to_tables_iterative bdd_term groupings_term =
399399

400400

401401
(* MTBDD to Rules:simple outout *)
402-
fun mtbdd_to_rules1 bdd_term =
402+
fun mtbdd_to_rules_all_paths bdd_term =
403403
let
404404
open pairSyntax listSyntax stringSyntax numSyntax;
405405

@@ -549,7 +549,7 @@ end
549549

550550

551551
(* MTBDD to Rules with or combinations grouped by action *)
552-
fun mtbdd_to_rules2 bdd_term =
552+
fun mtbdd_to_rules_grouped_by_action_simple bdd_term =
553553
let
554554
open pairSyntax listSyntax stringSyntax numSyntax;
555555

@@ -734,6 +734,336 @@ in
734734
end
735735

736736

737+
(* best policy output from BDD *)
738+
fun mtbdd_to_rules bdd_term =
739+
let
740+
open pairSyntax listSyntax stringSyntax numSyntax;
741+
742+
fun num_of_term t = Arbnum.toInt (dest_numeral t);
743+
744+
val (root_term, rest) = dest_pair bdd_term
745+
val (edges_term, labels_term) = dest_pair rest
746+
747+
val edges_list = fst (dest_list edges_term)
748+
val edges = map (fn edge =>
749+
let
750+
val (parent, children) = dest_pair edge
751+
val (left, right) = dest_pair children
752+
in
753+
(num_of_term parent, num_of_term left, num_of_term right)
754+
end) edges_list
755+
756+
val labels_list = fst (dest_list labels_term)
757+
val labels = map (fn label =>
758+
let val (id, data) = dest_pair label
759+
in (num_of_term id, data) end) labels_list
760+
761+
fun get_var_name node_data =
762+
let
763+
val (const, args) = strip_comb node_data
764+
val const_name = #Name (dest_thy_const const)
765+
in
766+
if const_name = "non_termn" then
767+
let
768+
val arg = hd args
769+
val (some_part, _) = dest_pair arg
770+
val (_, some_args) = strip_comb some_part
771+
val str_term = hd some_args
772+
in
773+
stringSyntax.fromHOLstring str_term
774+
end
775+
else
776+
raise Fail ("Not a non-terminal node: " ^ const_name)
777+
end
778+
779+
fun get_action node_data =
780+
let
781+
val (const, args) = strip_comb node_data
782+
val const_name = #Name (dest_thy_const const)
783+
in
784+
if const_name = "termn" then
785+
let
786+
val first_arg = hd args
787+
val (actual_action, _) = dest_pair first_arg
788+
handle _ => (first_arg, ``()``)
789+
in
790+
actual_action
791+
end
792+
else
793+
raise Fail ("Not a terminal node: " ^ const_name)
794+
end
795+
796+
fun node_type node_data =
797+
let
798+
val (const, _) = strip_comb node_data
799+
val const_name = #Name (dest_thy_const const)
800+
in
801+
const_name
802+
end
803+
804+
fun is_terminal node_data = node_type node_data = "termn"
805+
806+
val root_id = num_of_term root_term
807+
808+
(* Find all terminals in left-to-right order, remove duplicates *)
809+
fun dfs_collect node_id =
810+
let
811+
val node_data = case List.find (fn (id, _) => id = node_id) labels of
812+
SOME (_, data) => data
813+
| NONE => raise Fail ("Node not found: " ^ Int.toString node_id)
814+
in
815+
if is_terminal node_data then
816+
[node_id]
817+
else
818+
let
819+
val (left_child, right_child) =
820+
case List.find (fn (src, l, r) => src = node_id) edges of
821+
SOME (_, l, r) => (l, r)
822+
| NONE => raise Fail ("No edges from node " ^ Int.toString node_id)
823+
in
824+
dfs_collect left_child @ dfs_collect right_child
825+
end
826+
end
827+
828+
val terminal_ids = dfs_collect root_id
829+
val unique_terminals =
830+
let
831+
val seen = ref []
832+
fun keep_unique [] = []
833+
| keep_unique (x::xs) =
834+
if List.exists (fn y => y = x) (!seen) then keep_unique xs
835+
else (seen := x :: !seen; x :: keep_unique xs)
836+
in
837+
keep_unique terminal_ids
838+
end
839+
840+
(* Get left child of a node *)
841+
fun get_left_child node_id =
842+
case List.find (fn (src, l, r) => src = node_id) edges of
843+
SOME (_, left, _) => left
844+
| NONE => raise Fail ("No edges from node " ^ Int.toString node_id)
845+
846+
(* Get all parents of a node *)
847+
fun get_parents node_id =
848+
List.map (fn (src, _, _) => src)
849+
(List.filter (fn (_, left, right) => left = node_id orelse right = node_id) edges)
850+
851+
(* Check if a node can be a rule starting point *)
852+
fun can_be_rule_start node_id =
853+
if node_id = root_id then
854+
true (* Root can always start a rule *)
855+
else
856+
let
857+
(* Condition 1: All parents reach this node via RIGHT branch *)
858+
val incoming_edges = List.filter (fn (_, left, right) =>
859+
left = node_id orelse right = node_id) edges
860+
861+
val all_via_right = List.all (fn (_, left, right) => right = node_id) incoming_edges
862+
863+
(* Condition 2: Node's left child has only this node as parent (via left branch) *)
864+
val left_child = get_left_child node_id
865+
val left_child_parents = get_parents left_child
866+
867+
(* Check if left child has exactly one parent AND it's this node via left branch *)
868+
val left_child_has_single_parent =
869+
length left_child_parents = 1 andalso
870+
hd left_child_parents = node_id andalso
871+
List.exists (fn (src, left, _) => src = node_id andalso left = left_child) edges
872+
in
873+
all_via_right andalso left_child_has_single_parent
874+
end
875+
876+
(* Find starting nodes for rules - nodes that can start rules *)
877+
fun find_rule_starting_nodes target_terminal =
878+
let
879+
(* Find all nodes that point to target_terminal via LEFT branch *)
880+
val left_parents = List.map (fn (src, left, _) => src)
881+
(List.filter (fn (_, left, _) => left = target_terminal) edges)
882+
883+
(* For each left parent, find the closest ancestor that can start a rule *)
884+
fun find_starting_node_for_parent parent_id =
885+
let
886+
fun trace_back node_id =
887+
if can_be_rule_start node_id then
888+
node_id (* Found a valid starting point *)
889+
else if node_id = root_id then
890+
root_id (* Reached root, use it as starting point *)
891+
else
892+
(* Find parent and continue *)
893+
let
894+
val parents = get_parents node_id
895+
in
896+
case parents of
897+
[parent] => trace_back parent
898+
| _ => raise Fail ("Multiple or no parents for node " ^ Int.toString node_id)
899+
end
900+
in
901+
trace_back parent_id
902+
end
903+
904+
(* Also check if terminal is directly at root *)
905+
val root_start =
906+
if target_terminal = root_id then
907+
[root_id]
908+
else if List.exists (fn (src, left, _) => src = root_id andalso left = target_terminal) edges then
909+
[root_id]
910+
else []
911+
in
912+
(* Remove duplicates from starting nodes *)
913+
let
914+
val all_starts = root_start @ (map find_starting_node_for_parent left_parents)
915+
fun remove_dups [] = []
916+
| remove_dups (x::xs) =
917+
if List.exists (fn y => y = x) xs then remove_dups xs
918+
else x::remove_dups xs
919+
in
920+
remove_dups all_starts
921+
end
922+
end
923+
924+
(* Find all paths from a starting node to target terminal *)
925+
fun find_paths_from_start start_id target_id =
926+
let
927+
fun dfs current_id current_path =
928+
if current_id = target_id then
929+
[List.rev current_path] (* Found target *)
930+
else
931+
let
932+
val node_data = case List.find (fn (id, _) => id = current_id) labels of
933+
SOME (_, data) => data
934+
| NONE => raise Fail ("Node not found")
935+
in
936+
if is_terminal node_data then
937+
[] (* Different terminal *)
938+
else
939+
let
940+
val (left_child, right_child) =
941+
case List.find (fn (src, l, r) => src = current_id) edges of
942+
SOME (_, l, r) => (l, r)
943+
| NONE => raise Fail ("No edges from node")
944+
945+
val var_name = get_var_name node_data
946+
947+
(* Try left branch *)
948+
val left_paths = dfs left_child ((current_id, var_name, true)::current_path)
949+
950+
(* Try right branch *)
951+
val right_paths = dfs right_child ((current_id, var_name, false)::current_path)
952+
in
953+
left_paths @ right_paths
954+
end
955+
end
956+
in
957+
dfs start_id []
958+
end
959+
960+
(* Build rule for a terminal (except last one) *)
961+
fun build_rule_for_terminal term_id =
962+
let
963+
val term_data = case List.find (fn (id, _) => id = term_id) labels of
964+
SOME (_, data) => data
965+
| NONE => raise Fail ("Terminal not found")
966+
val action = get_action term_data
967+
968+
(* Find starting nodes for rules *)
969+
val starting_nodes = find_rule_starting_nodes term_id
970+
971+
(* For each starting node, find paths to terminal *)
972+
val all_paths =
973+
List.concat (map (fn start_id => find_paths_from_start start_id term_id) starting_nodes)
974+
975+
(* Convert a path to predicate: AND of variables where we took LEFT branch *)
976+
fun path_to_predicate path =
977+
let
978+
val positive_vars = List.map (fn (_, var_name, _) => var_name)
979+
(List.filter (fn (_, _, decision) => decision) path)
980+
981+
fun build_and [] = ``(True : pred)``
982+
| build_and [var] = ``(Var ^(stringSyntax.fromMLstring var)) : pred``
983+
| build_and (var::vars) =
984+
let
985+
val first_pred = ``(Var ^(stringSyntax.fromMLstring var)) : pred``
986+
val rest_pred = build_and vars
987+
in
988+
if aconv rest_pred ``(True : pred)`` then
989+
first_pred
990+
else
991+
``(And ^first_pred ^rest_pred) : pred``
992+
end
993+
in
994+
build_and positive_vars
995+
end
996+
997+
(* Convert all paths to predicates *)
998+
val predicates = map path_to_predicate all_paths
999+
1000+
(* Remove duplicate predicates *)
1001+
val unique_predicates =
1002+
let
1003+
fun remove_dups [] = []
1004+
| remove_dups (p::ps) =
1005+
if List.exists (fn q => aconv p q) ps then remove_dups ps
1006+
else p::remove_dups ps
1007+
in
1008+
remove_dups predicates
1009+
end
1010+
1011+
(* Combine unique predicates with OR *)
1012+
val final_predicate =
1013+
case unique_predicates of
1014+
[] => ``(True : pred)``
1015+
| [p] => p
1016+
| p::ps => List.foldl (fn (pred, acc) => ``(Or ^acc ^pred) : pred``) p ps
1017+
in
1018+
(final_predicate, action)
1019+
end
1020+
1021+
(* Build all rules - last terminal gets True *)
1022+
fun build_rules [] = []
1023+
| build_rules terminals =
1024+
let
1025+
val num_terms = length terminals
1026+
fun build idx remaining =
1027+
case remaining of
1028+
[] => []
1029+
| [term_id] => (* Last terminal gets True *)
1030+
let
1031+
val term_data = case List.find (fn (id, _) => id = term_id) labels of
1032+
SOME (_, data) => data
1033+
| NONE => raise Fail ("Terminal not found")
1034+
val action = get_action term_data
1035+
in
1036+
[(``(True : pred)``, action)]
1037+
end
1038+
| term_id::rest =>
1039+
let
1040+
val rule = build_rule_for_terminal term_id
1041+
in
1042+
rule :: build (idx+1) rest
1043+
end
1044+
in
1045+
build 0 terminals
1046+
end
1047+
1048+
val rules = build_rules unique_terminals
1049+
1050+
val rule_terms =
1051+
if null rules then
1052+
listSyntax.mk_list ([], ``:pred # action``)
1053+
else
1054+
let
1055+
val rule_pairs = map (fn (pred, act) => mk_pair (pred, act)) rules
1056+
val pair_type = type_of (hd rule_pairs)
1057+
in
1058+
listSyntax.mk_list (rule_pairs, pair_type)
1059+
end
1060+
1061+
in
1062+
rule_terms
1063+
end
1064+
1065+
1066+
7371067
end
7381068

7391069
(*

0 commit comments

Comments
 (0)