|
15 | 15 | from causallearn.utils.cit import * |
16 | 16 | from causallearn.utils.FAS import fas |
17 | 17 | from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge |
18 | | - |
| 18 | +from itertools import combinations |
19 | 19 |
|
20 | 20 | def traverseSemiDirected(node: Node, edge: Edge) -> Node | None: |
21 | 21 | if node == edge.get_node1(): |
@@ -320,8 +320,9 @@ def rulesR1R2cycle(graph: Graph, bk: BackgroundKnowledge | None, changeFlag: boo |
320 | 320 |
|
def isNoncollider(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], node_i: Node, node_j: Node,
                  node_k: Node) -> bool:
    """
    Tell whether node_j belongs to the separating set recorded for (node_i, node_k).

    Parameters
    ----------
    graph : graph whose node map translates nodes to the integer indices used as keys in sep_sets
    sep_sets : separating sets keyed by a pair of node indices
    node_i, node_j, node_k : the triple under consideration; node_j is the middle node

    Returns
    -------
    True if a separating set is stored for (node_i, node_k) and it contains node_j;
    False if no set is stored (missing key or None) or node_j is not in it.
    """
    index_of = graph.get_node_map()
    key = (index_of[node_i], index_of[node_k])

    # .get() keeps a missing pair from raising; it is treated like "no separating set".
    separating = sep_sets.get(key)
    if separating is None:
        return False
    return index_of[node_j] in separating
325 | 326 |
|
326 | 327 |
|
327 | 328 | def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: BackgroundKnowledge | None, changeFlag: bool, |
@@ -542,6 +543,142 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m |
542 | 543 | return change_flag |
543 | 544 |
|
544 | 545 |
|
| 546 | + |
def rule8(graph: Graph, nodes: List[Node]):
    """
    Apply FCI orientation rule R8 to the graph in place.

    R8: if A -> B -> C or A -o B -> C, and A o-> C, then orient A o-> C as A -> C.

    Parameters
    ----------
    graph : graph to orient; matching A o-> C edges are replaced by A -> C
    nodes : accepted for call compatibility only; it is ignored and the node list is
            always taken from `graph.get_nodes()`

    Returns
    -------
    changeFlag : True if at least one edge was re-oriented, False otherwise
    """
    # Local import so the block does not depend on the module-level import list.
    from itertools import permutations

    changeFlag = False
    for node_B in graph.get_nodes():
        adj = graph.get_adjacent_nodes(node_B)
        if len(adj) < 2:
            continue

        # R8 is not symmetric in A and C, so every *ordered* pair of neighbours of B
        # has to be tried; unordered pairs would make the result depend on the order
        # of the adjacency list.
        for node_A, node_C in permutations(adj, 2):
            if not graph.is_adjacent_to(node_A, node_C):
                continue

            # A o-> C : arrowhead at C, circle at A (this is the edge to re-orient).
            if graph.get_endpoint(node_A, node_C) != Endpoint.ARROW or \
                    graph.get_endpoint(node_C, node_A) != Endpoint.CIRCLE:
                continue

            # B -> C : arrowhead at C, tail at B.
            if graph.get_endpoint(node_B, node_C) != Endpoint.ARROW or \
                    graph.get_endpoint(node_C, node_B) != Endpoint.TAIL:
                continue

            # A -> B or A -o B : tail at A, arrowhead or circle at B.
            if graph.get_endpoint(node_B, node_A) != Endpoint.TAIL:
                continue
            if graph.get_endpoint(node_A, node_B) not in (Endpoint.ARROW, Endpoint.CIRCLE):
                continue

            edge1 = graph.get_edge(node_A, node_C)
            graph.remove_edge(edge1)
            graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
            changeFlag = True

    return changeFlag
| 577 | + |
| 578 | + |
| 579 | + |
def is_possible_parent(graph: Graph, potential_parent_node, child_node):
    """
    Decide whether `potential_parent_node` could be a parent of `child_node`.

    Returns False when the two nodes are the same node (same index in the graph's
    node map) or are not adjacent. Otherwise the edge between them must have
    neither an arrowhead at the potential parent nor a tail at the child.

    Parameters
    ----------
    graph : graph holding both nodes
    potential_parent_node : candidate parent
    child_node : candidate child

    Returns
    -------
    bool
    """
    same_node = graph.node_map[potential_parent_node] == graph.node_map[child_node]
    if same_node or not graph.is_adjacent_to(potential_parent_node, child_node):
        return False

    # An arrowhead at the parent end or a tail at the child end rules the edge out.
    return (graph.get_endpoint(child_node, potential_parent_node) != Endpoint.ARROW
            and graph.get_endpoint(potential_parent_node, child_node) != Endpoint.TAIL)
| 591 | + |
| 592 | + |
def find_possible_children(graph: Graph, parent_node, en_nodes=None):
    """
    Collect the nodes for which `parent_node` is a possible parent.

    Parameters
    ----------
    graph : graph to inspect
    parent_node : node whose possible children are wanted
    en_nodes : candidate nodes to test; when None, every node of the graph other
               than `parent_node` is tested

    Returns
    -------
    set of nodes from the candidates that pass `is_possible_parent`
    """
    candidates = en_nodes
    if candidates is None:
        candidates = [
            other for other in graph.get_nodes()
            if graph.node_map[other] != graph.node_map[parent_node]
        ]

    return {
        candidate for candidate in candidates
        if is_possible_parent(graph, potential_parent_node=parent_node, child_node=candidate)
    }
| 604 | + |
def rule9(graph: Graph, nodes: List[Node]):
    """
    Apply FCI orientation rule R9 to the graph in place.

    For every edge A o-> C, look for a neighbour B of A (other than C) that is a
    possible child of A, is not adjacent to C, and from which a semi-directed path
    to C exists. If one is found, A o-> C is replaced by A -> C.

    NOTE(review): path existence is delegated to `existsSemiDirectedPath`; this
    function does not itself check that the path is uncovered -- confirm against
    the helper.

    Parameters
    ----------
    graph : graph to orient
    nodes : accepted for call compatibility only; it is ignored and the node list is
            always taken from `graph.get_nodes()`

    Returns
    -------
    changeFlag : True if at least one edge was re-oriented, False otherwise
    """
    changeFlag = False
    for node_C in graph.get_nodes():
        for node_A in graph.get_nodes_into(node_C, Endpoint.ARROW):
            # Only A o-> C qualifies: the mark at A must be a circle.
            if graph.get_endpoint(node_C, node_A) != Endpoint.CIRCLE:
                continue

            # Candidates for B: neighbours of A that are neither A nor C.
            idx_a = graph.node_map[node_A]
            idx_c = graph.node_map[node_C]
            neighbours = [
                other for other in graph.get_adjacent_nodes(node_A)
                if graph.node_map[other] not in (idx_a, idx_c)
            ]

            for node_B in find_possible_children(graph, node_A, neighbours):
                # B must not be adjacent to C and must reach C by a semi-directed path.
                if graph.is_adjacent_to(node_B, node_C):
                    continue
                if not existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph):
                    continue

                old_edge = graph.get_edge(node_A, node_C)
                graph.remove_edge(old_edge)
                graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
                changeFlag = True
                # A o-> C is now A -> C; stop scanning B and move on to the next A.
                break
    return changeFlag
| 631 | + |
| 632 | + |
def _rule10_has_witness(graph: Graph, possible_children, directed_parents: List[Node]) -> bool:
    """Return True if two non-adjacent possible children reach two distinct directed parents."""
    for node_B, node_D in combinations(directed_parents, 2):
        for child_one, child_two in combinations(possible_children, 2):
            if graph.is_adjacent_to(child_one, child_two):
                continue
            # Both pairs are unordered, so either child may be the one leading to B.
            if (existsSemiDirectedPath(node_from=child_one, node_to=node_B, G=graph)
                    and existsSemiDirectedPath(node_from=child_two, node_to=node_D, G=graph)):
                return True
            if (existsSemiDirectedPath(node_from=child_one, node_to=node_D, G=graph)
                    and existsSemiDirectedPath(node_from=child_two, node_to=node_B, G=graph)):
                return True
    return False


def rule10(graph: Graph):
    """
    Apply FCI orientation rule R10 to the graph in place.

    R10: given A o-> C and B -> C <- D, if two distinct, non-adjacent possible
    children of A start semi-directed paths that reach B and D respectively, then
    A o-> C is replaced by A -> C.

    NOTE(review): path existence is delegated to `existsSemiDirectedPath`; this
    function does not itself check that the paths are uncovered -- confirm against
    the helper.

    Parameters
    ----------
    graph : graph to orient

    Returns
    -------
    changeFlag : True if at least one edge was re-oriented, False otherwise
    """
    changeFlag = False
    for node_C in graph.get_nodes():
        intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
        if len(intoCArrows) < 2:
            continue

        # All A with A o-> C (circle mark at A).
        Anodes = [node for node in intoCArrows if graph.get_endpoint(node_C, node) == Endpoint.CIRCLE]

        for node_A in Anodes:
            en_nodes = [node for node in graph.get_adjacent_nodes(node_A) if node is not node_C]
            A_possible_children = find_possible_children(graph, parent_node=node_A, en_nodes=en_nodes)
            if len(A_possible_children) < 2:
                continue

            # Candidates for B and D: nodes with a fully directed edge into C.
            # Re-read for every A because an edge oriented earlier in this loop
            # turns that earlier A into a directed parent of C.
            directed_parents = [node for node in intoCArrows
                                if graph.get_endpoint(node_C, node) == Endpoint.TAIL]

            # The helper returns on the first witness, so the edge is oriented once per A.
            if _rule10_has_witness(graph, A_possible_children, directed_parents):
                edge1 = graph.get_edge(node_A, node_C)
                graph.remove_edge(edge1)
                graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
                changeFlag = True

    return changeFlag
| 680 | + |
| 681 | + |
545 | 682 | def visibleEdgeHelperVisit(graph: Graph, node_c: Node, node_a: Node, node_b: Node, path: List[Node]) -> bool: |
546 | 683 | if path.__contains__(node_a): |
547 | 684 | return False |
@@ -691,10 +828,8 @@ def _contains_all(set_a: Set[Node], set_b: Set[Node]): |
691 | 828 | break |
692 | 829 |
|
693 | 830 |
|
694 | | - |
695 | 831 | def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1, |
696 | | - max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, |
697 | | - show_progress: bool = True, node_names = None, |
| 832 | + max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True, |
698 | 833 | **kwargs) -> Tuple[Graph, List[Edge]]: |
699 | 834 | """ |
700 | 835 | Perform Fast Causal Inference (FCI) algorithm for causal discovery |
@@ -749,10 +884,8 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = |
749 | 884 |
|
750 | 885 |
|
751 | 886 | nodes = [] |
752 | | - if node_names is None: |
753 | | - node_names = [f"X{i + 1}" for i in range(dataset.shape[1])] |
754 | 887 | for i in range(dataset.shape[1]): |
755 | | - node = GraphNode(node_names[i]) |
| 888 | + node = GraphNode(f"X{i + 1}") |
756 | 889 | node.add_attribute("id", i) |
757 | 890 | nodes.append(node) |
758 | 891 |
|
@@ -790,6 +923,13 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = |
790 | 923 | if verbose: |
791 | 924 | print("Epoch") |
792 | 925 |
|
| 926 | + # rule 8 |
| 927 | + change_flag = rule8(graph,nodes) |
| 928 | + # rule 9 |
| 929 | + change_flag = rule9(graph, nodes) |
| 930 | + # rule 10 |
| 931 | + change_flag = rule10(graph) |
| 932 | + |
793 | 933 | graph.set_pag(True) |
794 | 934 |
|
795 | 935 | edges = get_color_edges(graph) |
|
0 commit comments