Skip to content

Commit 8c72b52

Browse files
authored
Merge pull request #206 from sherryzyh/main
Graph operations compatible with np array
2 parents 08655cd + 705bcec commit 8c72b52

File tree

3 files changed

+40
-11
lines changed

3 files changed

+40
-11
lines changed

causallearn/graph/Dag.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
from itertools import combinations
3-
from typing import List
3+
from typing import List, Optional, Union
44

55
import networkx as nx
66
import numpy as np
@@ -18,8 +18,23 @@
1818
# or latent, with at most one edge per node pair, and no edges to self.
1919
class Dag(GeneralGraph):
2020

21-
def __init__(self, nodes: List[Node]):
22-
21+
def __init__(self, nodes: Optional[List[Node]]=None, graph: Union[np.ndarray, nx.Graph, None]=None):
22+
if nodes is not None:
23+
self._init_from_nodes(nodes)
24+
elif graph is not None:
25+
if isinstance(graph, np.ndarray):
26+
nodes = [Node(node_name=str(i)) for i in range(len(graph))]
27+
self._init_from_nodes(nodes)
28+
for i in range(len(nodes)):
29+
for j in range(len(nodes)):
30+
if graph[i, j] == 1:
31+
self.add_directed_edge(nodes[i], nodes[j])
32+
else:
33+
pass
34+
else:
35+
raise ValueError("Dag.__init__() requires argument 'nodes' or 'graph'")
36+
37+
def _init_from_nodes(self, nodes: List[Node]):
2338
# for node in nodes:
2439
# if not isinstance(node, type(GraphNode)):
2540
# raise TypeError("Graphs must be instantiated with a list of GraphNodes")

causallearn/graph/Node.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,34 @@
22

33
# Represents an object with a name, node type, and position that can serve as a
44
# node in a graph.
5+
from typing import Optional
56
from causallearn.graph.NodeType import NodeType
67
from causallearn.graph.NodeVariableType import NodeVariableType
78

89

910
class Node:
11+
node_type: NodeType
12+
node_name: str
1013

14+
def __init__(self, node_name: Optional[str] = None, node_type: Optional[NodeType] = None) -> None:
15+
self.node_name = node_name
16+
self.node_type = node_type
17+
1118
# @return the name of the variable.
1219
def get_name(self) -> str:
13-
pass
20+
return self.node_name
1421

1522
# set the name of the variable
1623
def set_name(self, name: str):
17-
pass
24+
self.node_name = name
1825

1926
# @return the node type of the variable
2027
def get_node_type(self) -> NodeType:
21-
pass
28+
return self.node_type
2229

2330
# set the node type of the variable
2431
def set_node_type(self, node_type: NodeType):
25-
pass
32+
self.node_type = node_type
2633

2734
# @return the intervention type
2835
def get_node_variable_type(self) -> NodeVariableType:
@@ -35,7 +42,7 @@ def set_node_variable_type(self, var_type: NodeVariableType):
3542

3643
# @return the name of the node as its string representation
3744
def __str__(self):
38-
pass
45+
return self.node_name
3946

4047
# @return the x coordinate of the center of the node
4148
def get_center_x(self) -> int:
@@ -59,7 +66,7 @@ def set_center(self, center_x: int, center_y: int):
5966

6067
# @return a hashcode for this variable
6168
def __hash__(self):
62-
pass
69+
return hash(self.node_name)
6370

6471
# @return true iff this variable is equal to the given variable
6572
def __eq__(self, other):

causallearn/utils/DAG2CPDAG.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Union
12
import numpy as np
23

34
from causallearn.graph.Dag import Dag
@@ -6,7 +7,7 @@
67
from causallearn.graph.GeneralGraph import GeneralGraph
78

89

9-
def dag2cpdag(G: Dag) -> GeneralGraph:
10+
def dag2cpdag(G: Union[Dag, np.ndarray]) -> GeneralGraph:
1011
"""
1112
Convert a DAG to its corresponding PDAG
1213
@@ -22,7 +23,13 @@ def dag2cpdag(G: Dag) -> GeneralGraph:
2223
-------
2324
Yuequn Liu@dmirlab, Wei Chen@dmirlab, Kun Zhang@CMU
2425
"""
25-
26+
27+
if isinstance(G, np.ndarray):
28+
# convert np array to Dag graph
29+
G = Dag(graph=G)
30+
elif not isinstance(G, Dag):
31+
raise TypeError("parameter graph should be `Dag` or `np.ndarry`")
32+
2633
# order the edges in G
2734
nodes_order = list(
2835
map(lambda x: G.node_map[x], G.get_causal_ordering())

0 commit comments

Comments
 (0)