|
15 | 15 | from causallearn.utils.cit import * |
16 | 16 | from causallearn.utils.FAS import fas |
17 | 17 | from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge |
18 | | -from itertools import combinations |
| 18 | + |
19 | 19 |
|
20 | 20 | def traverseSemiDirected(node: Node, edge: Edge) -> Node | None: |
21 | 21 | if node == edge.get_node1(): |
@@ -320,9 +320,8 @@ def rulesR1R2cycle(graph: Graph, bk: BackgroundKnowledge | None, changeFlag: boo |
320 | 320 |
|
321 | 321 | def isNoncollider(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], node_i: Node, node_j: Node, |
322 | 322 | node_k: Node) -> bool: |
323 | | - node_map = graph.get_node_map() |
324 | | - sep_set = sep_sets.get((node_map[node_i], node_map[node_k])) |
325 | | - return sep_set is not None and sep_set.__contains__(node_map[node_j]) |
| 323 | + sep_set = sep_sets[(graph.get_node_map()[node_i], graph.get_node_map()[node_k])] |
| 324 | + return sep_set is not None and sep_set.__contains__(graph.get_node_map()[node_j]) |
326 | 325 |
|
327 | 326 |
|
328 | 327 | def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: BackgroundKnowledge | None, changeFlag: bool, |
@@ -543,142 +542,6 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m |
543 | 542 | return change_flag |
544 | 543 |
|
545 | 544 |
|
546 | | - |
547 | | -def rule8(graph: Graph, nodes: List[Node]): |
548 | | - nodes = graph.get_nodes() |
549 | | - changeFlag = False |
550 | | - for node_B in nodes: |
551 | | - adj = graph.get_adjacent_nodes(node_B) |
552 | | - if len(adj) < 2: |
553 | | - continue |
554 | | - |
555 | | - cg = ChoiceGenerator(len(adj), 2) |
556 | | - combination = cg.next() |
557 | | - |
558 | | - while combination is not None: |
559 | | - node_A = adj[combination[0]] |
560 | | - node_C = adj[combination[1]] |
561 | | - combination = cg.next() |
562 | | - |
563 | | - if(graph.get_endpoint(node_A, node_B) == Endpoint.ARROW and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \ |
564 | | - graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \ |
565 | | - graph.is_adjacent_to(node_A, node_C) and \ |
566 | | - graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE) or \ |
567 | | - (graph.get_endpoint(node_A, node_B) == Endpoint.CIRCLE and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \ |
568 | | - graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \ |
569 | | - graph.is_adjacent_to(node_A, node_C) and \ |
570 | | - graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE): |
571 | | - edge1 = graph.get_edge(node_A, node_C) |
572 | | - graph.remove_edge(edge1) |
573 | | - graph.add_edge(Edge(node_A, node_C,Endpoint.TAIL, Endpoint.ARROW)) |
574 | | - changeFlag = True |
575 | | - |
576 | | - return changeFlag |
577 | | - |
578 | | - |
579 | | - |
580 | | -def is_possible_parent(graph: Graph, potential_parent_node, child_node): |
581 | | - if graph.node_map[potential_parent_node] == graph.node_map[child_node]: |
582 | | - return False |
583 | | - if not graph.is_adjacent_to(potential_parent_node, child_node): |
584 | | - return False |
585 | | - |
586 | | - if graph.get_endpoint(child_node, potential_parent_node) == Endpoint.ARROW or \ |
587 | | - graph.get_endpoint(potential_parent_node, child_node) == Endpoint.TAIL: |
588 | | - return False |
589 | | - else: |
590 | | - return True |
591 | | - |
592 | | - |
593 | | -def find_possible_children(graph: Graph, parent_node, en_nodes=None): |
594 | | - if en_nodes is None: |
595 | | - nodes = graph.get_nodes() |
596 | | - en_nodes = [node for node in nodes if graph.node_map[node] != graph.node_map[parent_node]] |
597 | | - |
598 | | - potential_child_nodes = set() |
599 | | - for potential_node in en_nodes: |
600 | | - if is_possible_parent(graph, potential_parent_node=parent_node, child_node=potential_node): |
601 | | - potential_child_nodes.add(potential_node) |
602 | | - |
603 | | - return potential_child_nodes |
604 | | - |
605 | | -def rule9(graph: Graph, nodes: List[Node]): |
606 | | - changeFlag = False |
607 | | - nodes = graph.get_nodes() |
608 | | - for node_C in nodes: |
609 | | - intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW) |
610 | | - for node_A in intoCArrows: |
611 | | - # we want A o--> C |
612 | | - if not graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE: |
613 | | - continue |
614 | | - |
615 | | - # look for a possibly directed uncovered path s.t. B and C are not connected (for the given A o--> C |
616 | | - a_node_idx = graph.node_map[node_A] |
617 | | - c_node_idx = graph.node_map[node_C] |
618 | | - a_adj_nodes = graph.get_adjacent_nodes(node_A) |
619 | | - nodes_set = [node for node in a_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= c_node_idx] |
620 | | - possible_children = find_possible_children(graph, node_A, nodes_set) |
621 | | - for node_B in possible_children: |
622 | | - if graph.is_adjacent_to(node_B, node_C): |
623 | | - continue |
624 | | - if existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph): |
625 | | - edge1 = graph.get_edge(node_A, node_C) |
626 | | - graph.remove_edge(edge1) |
627 | | - graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW)) |
628 | | - changeFlag = True |
629 | | - break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A |
630 | | - return changeFlag |
631 | | - |
632 | | - |
633 | | -def rule10(graph: Graph): |
634 | | - changeFlag = False |
635 | | - nodes = graph.get_nodes() |
636 | | - for node_C in nodes: |
637 | | - intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW) |
638 | | - if len(intoCArrows) < 2: |
639 | | - continue |
640 | | - # get all A where A o-> C |
641 | | - Anodes = [node_A for node_A in intoCArrows if graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE] |
642 | | - if len(Anodes) == 0: |
643 | | - continue |
644 | | - |
645 | | - for node_A in Anodes: |
646 | | - A_adj_nodes = graph.get_adjacent_nodes(node_A) |
647 | | - en_nodes = [i for i in A_adj_nodes if i is not node_C] |
648 | | - A_possible_children = find_possible_children(graph, parent_node=node_A, en_nodes=en_nodes) |
649 | | - if len(A_possible_children) < 2: |
650 | | - continue |
651 | | - |
652 | | - gen = ChoiceGenerator(len(intoCArrows), 2) |
653 | | - choice = gen.next() |
654 | | - while choice is not None: |
655 | | - node_B = intoCArrows[choice[0]] |
656 | | - node_D = intoCArrows[choice[1]] |
657 | | - |
658 | | - choice = gen.next() |
659 | | - # we want B->C<-D |
660 | | - if graph.get_endpoint(node_C, node_B) != Endpoint.TAIL: |
661 | | - continue |
662 | | - |
663 | | - if graph.get_endpoint(node_C, node_D) != Endpoint.TAIL: |
664 | | - continue |
665 | | - |
666 | | - for children in combinations(A_possible_children, 2): |
667 | | - child_one, child_two = children |
668 | | - if not existsSemiDirectedPath(node_from=child_one, node_to=node_B, G=graph) or \ |
669 | | - not existsSemiDirectedPath(node_from=child_two, node_to=node_D, G=graph): |
670 | | - continue |
671 | | - |
672 | | - if not graph.is_adjacent_to(child_one, child_two): |
673 | | - edge1 = graph.get_edge(node_A, node_C) |
674 | | - graph.remove_edge(edge1) |
675 | | - graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW)) |
676 | | - changeFlag = True |
677 | | - break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A |
678 | | - |
679 | | - return changeFlag |
680 | | - |
681 | | - |
682 | 545 | def visibleEdgeHelperVisit(graph: Graph, node_c: Node, node_a: Node, node_b: Node, path: List[Node]) -> bool: |
683 | 546 | if path.__contains__(node_a): |
684 | 547 | return False |
@@ -828,8 +691,10 @@ def _contains_all(set_a: Set[Node], set_b: Set[Node]): |
828 | 691 | break |
829 | 692 |
|
830 | 693 |
|
| 694 | + |
831 | 695 | def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1, |
832 | | - max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True, |
| 696 | + max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, |
| 697 | + show_progress: bool = True, node_names = None, |
833 | 698 | **kwargs) -> Tuple[Graph, List[Edge]]: |
834 | 699 | """ |
835 | 700 | Perform Fast Causal Inference (FCI) algorithm for causal discovery |
@@ -884,8 +749,10 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = |
884 | 749 |
|
885 | 750 |
|
886 | 751 | nodes = [] |
| 752 | + if node_names is None: |
| 753 | + node_names = [f"X{i + 1}" for i in range(dataset.shape[1])] |
887 | 754 | for i in range(dataset.shape[1]): |
888 | | - node = GraphNode(f"X{i + 1}") |
| 755 | + node = GraphNode(node_names[i]) |
889 | 756 | node.add_attribute("id", i) |
890 | 757 | nodes.append(node) |
891 | 758 |
|
@@ -923,13 +790,6 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = |
923 | 790 | if verbose: |
924 | 791 | print("Epoch") |
925 | 792 |
|
926 | | - # rule 8 |
927 | | - change_flag = rule8(graph,nodes) |
928 | | - # rule 9 |
929 | | - change_flag = rule9(graph, nodes) |
930 | | - # rule 10 |
931 | | - change_flag = rule10(graph) |
932 | | - |
933 | 793 | graph.set_pag(True) |
934 | 794 |
|
935 | 795 | edges = get_color_edges(graph) |
|
0 commit comments