Skip to content

Commit c2210d0

Browse files
committed
Add node_names
1 parent 8b7f591 commit c2210d0

File tree

1 file changed

+9
-149
lines changed
  • causallearn/search/ConstraintBased

1 file changed

+9
-149
lines changed

causallearn/search/ConstraintBased/FCI.py

Lines changed: 9 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from causallearn.utils.cit import *
1616
from causallearn.utils.FAS import fas
1717
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
18-
from itertools import combinations
18+
1919

2020
def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
2121
if node == edge.get_node1():
@@ -320,9 +320,8 @@ def rulesR1R2cycle(graph: Graph, bk: BackgroundKnowledge | None, changeFlag: boo
320320

321321
def isNoncollider(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], node_i: Node, node_j: Node,
322322
node_k: Node) -> bool:
323-
node_map = graph.get_node_map()
324-
sep_set = sep_sets.get((node_map[node_i], node_map[node_k]))
325-
return sep_set is not None and sep_set.__contains__(node_map[node_j])
323+
sep_set = sep_sets[(graph.get_node_map()[node_i], graph.get_node_map()[node_k])]
324+
return sep_set is not None and sep_set.__contains__(graph.get_node_map()[node_j])
326325

327326

328327
def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: BackgroundKnowledge | None, changeFlag: bool,
@@ -543,142 +542,6 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m
543542
return change_flag
544543

545544

546-
547-
def rule8(graph: Graph, nodes: List[Node]):
548-
nodes = graph.get_nodes()
549-
changeFlag = False
550-
for node_B in nodes:
551-
adj = graph.get_adjacent_nodes(node_B)
552-
if len(adj) < 2:
553-
continue
554-
555-
cg = ChoiceGenerator(len(adj), 2)
556-
combination = cg.next()
557-
558-
while combination is not None:
559-
node_A = adj[combination[0]]
560-
node_C = adj[combination[1]]
561-
combination = cg.next()
562-
563-
if(graph.get_endpoint(node_A, node_B) == Endpoint.ARROW and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
564-
graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
565-
graph.is_adjacent_to(node_A, node_C) and \
566-
graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE) or \
567-
(graph.get_endpoint(node_A, node_B) == Endpoint.CIRCLE and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
568-
graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
569-
graph.is_adjacent_to(node_A, node_C) and \
570-
graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE):
571-
edge1 = graph.get_edge(node_A, node_C)
572-
graph.remove_edge(edge1)
573-
graph.add_edge(Edge(node_A, node_C,Endpoint.TAIL, Endpoint.ARROW))
574-
changeFlag = True
575-
576-
return changeFlag
577-
578-
579-
580-
def is_possible_parent(graph: Graph, potential_parent_node, child_node):
581-
if graph.node_map[potential_parent_node] == graph.node_map[child_node]:
582-
return False
583-
if not graph.is_adjacent_to(potential_parent_node, child_node):
584-
return False
585-
586-
if graph.get_endpoint(child_node, potential_parent_node) == Endpoint.ARROW or \
587-
graph.get_endpoint(potential_parent_node, child_node) == Endpoint.TAIL:
588-
return False
589-
else:
590-
return True
591-
592-
593-
def find_possible_children(graph: Graph, parent_node, en_nodes=None):
594-
if en_nodes is None:
595-
nodes = graph.get_nodes()
596-
en_nodes = [node for node in nodes if graph.node_map[node] != graph.node_map[parent_node]]
597-
598-
potential_child_nodes = set()
599-
for potential_node in en_nodes:
600-
if is_possible_parent(graph, potential_parent_node=parent_node, child_node=potential_node):
601-
potential_child_nodes.add(potential_node)
602-
603-
return potential_child_nodes
604-
605-
def rule9(graph: Graph, nodes: List[Node]):
606-
changeFlag = False
607-
nodes = graph.get_nodes()
608-
for node_C in nodes:
609-
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
610-
for node_A in intoCArrows:
611-
# we want A o--> C
612-
if not graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE:
613-
continue
614-
615-
# look for a possibly directed uncovered path s.t. B and C are not connected (for the given A o--> C
616-
a_node_idx = graph.node_map[node_A]
617-
c_node_idx = graph.node_map[node_C]
618-
a_adj_nodes = graph.get_adjacent_nodes(node_A)
619-
nodes_set = [node for node in a_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= c_node_idx]
620-
possible_children = find_possible_children(graph, node_A, nodes_set)
621-
for node_B in possible_children:
622-
if graph.is_adjacent_to(node_B, node_C):
623-
continue
624-
if existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph):
625-
edge1 = graph.get_edge(node_A, node_C)
626-
graph.remove_edge(edge1)
627-
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
628-
changeFlag = True
629-
break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A
630-
return changeFlag
631-
632-
633-
def rule10(graph: Graph):
634-
changeFlag = False
635-
nodes = graph.get_nodes()
636-
for node_C in nodes:
637-
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
638-
if len(intoCArrows) < 2:
639-
continue
640-
# get all A where A o-> C
641-
Anodes = [node_A for node_A in intoCArrows if graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE]
642-
if len(Anodes) == 0:
643-
continue
644-
645-
for node_A in Anodes:
646-
A_adj_nodes = graph.get_adjacent_nodes(node_A)
647-
en_nodes = [i for i in A_adj_nodes if i is not node_C]
648-
A_possible_children = find_possible_children(graph, parent_node=node_A, en_nodes=en_nodes)
649-
if len(A_possible_children) < 2:
650-
continue
651-
652-
gen = ChoiceGenerator(len(intoCArrows), 2)
653-
choice = gen.next()
654-
while choice is not None:
655-
node_B = intoCArrows[choice[0]]
656-
node_D = intoCArrows[choice[1]]
657-
658-
choice = gen.next()
659-
# we want B->C<-D
660-
if graph.get_endpoint(node_C, node_B) != Endpoint.TAIL:
661-
continue
662-
663-
if graph.get_endpoint(node_C, node_D) != Endpoint.TAIL:
664-
continue
665-
666-
for children in combinations(A_possible_children, 2):
667-
child_one, child_two = children
668-
if not existsSemiDirectedPath(node_from=child_one, node_to=node_B, G=graph) or \
669-
not existsSemiDirectedPath(node_from=child_two, node_to=node_D, G=graph):
670-
continue
671-
672-
if not graph.is_adjacent_to(child_one, child_two):
673-
edge1 = graph.get_edge(node_A, node_C)
674-
graph.remove_edge(edge1)
675-
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
676-
changeFlag = True
677-
break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A
678-
679-
return changeFlag
680-
681-
682545
def visibleEdgeHelperVisit(graph: Graph, node_c: Node, node_a: Node, node_b: Node, path: List[Node]) -> bool:
683546
if path.__contains__(node_a):
684547
return False
@@ -828,8 +691,10 @@ def _contains_all(set_a: Set[Node], set_b: Set[Node]):
828691
break
829692

830693

694+
831695
def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1,
832-
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True,
696+
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None,
697+
show_progress: bool = True, node_names = None,
833698
**kwargs) -> Tuple[Graph, List[Edge]]:
834699
"""
835700
Perform Fast Causal Inference (FCI) algorithm for causal discovery
@@ -884,8 +749,10 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
884749

885750

886751
nodes = []
752+
if node_names is None:
753+
node_names = [f"X{i + 1}" for i in range(dataset.shape[1])]
887754
for i in range(dataset.shape[1]):
888-
node = GraphNode(f"X{i + 1}")
755+
node = GraphNode(node_names[i])
889756
node.add_attribute("id", i)
890757
nodes.append(node)
891758

@@ -923,13 +790,6 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
923790
if verbose:
924791
print("Epoch")
925792

926-
# rule 8
927-
change_flag = rule8(graph,nodes)
928-
# rule 9
929-
change_flag = rule9(graph, nodes)
930-
# rule 10
931-
change_flag = rule10(graph)
932-
933793
graph.set_pag(True)
934794

935795
edges = get_color_edges(graph)

0 commit comments

Comments
 (0)