Skip to content

Commit a74272f

Browse files
authored
Merge pull request #109 from MarkDana/fix-endpoint-comparison-bug
Fixed the endpoint comparison bug
2 parents 446b9a2 + b3beba7 commit a74272f

File tree

4 files changed

+35
-32
lines changed

4 files changed

+35
-32
lines changed

causallearn/graph/Edge.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -82,28 +82,28 @@ def set_endpoint1(self, endpoint: Endpoint):
8282
self.endpoint1 = endpoint
8383

8484
if self.numerical_endpoint_1 == 1 and self.numerical_endpoint_2 == 1:
85-
if endpoint is Endpoint.ARROW:
85+
if endpoint == Endpoint.ARROW:
8686
pass
8787
else:
88-
if endpoint is Endpoint.TAIL:
88+
if endpoint == Endpoint.TAIL:
8989
self.numerical_endpoint_1 = -1
9090
self.numerical_endpoint_2 = 1
9191
else:
92-
if endpoint is Endpoint.CIRCLE:
92+
if endpoint == Endpoint.CIRCLE:
9393
self.numerical_endpoint_1 = 2
9494
self.numerical_endpoint_2 = 1
9595
else:
96-
if endpoint is Endpoint.ARROW and self.numerical_endpoint_2 == 1:
96+
if endpoint == Endpoint.ARROW and self.numerical_endpoint_2 == 1:
9797
self.numerical_endpoint_1 = 1
9898
self.numerical_endpoint_2 = 1
9999
else:
100-
if endpoint is Endpoint.ARROW:
100+
if endpoint == Endpoint.ARROW:
101101
self.numerical_endpoint_1 = 1
102102
else:
103-
if endpoint is Endpoint.TAIL:
103+
if endpoint == Endpoint.TAIL:
104104
self.numerical_endpoint_1 = -1
105105
else:
106-
if endpoint is Endpoint.CIRCLE:
106+
if endpoint == Endpoint.CIRCLE:
107107
self.numerical_endpoint_1 = 2
108108

109109
if self.pointing_left(self.endpoint1, self.endpoint2):
@@ -123,28 +123,28 @@ def set_endpoint2(self, endpoint: Endpoint):
123123
self.endpoint2 = endpoint
124124

125125
if self.numerical_endpoint_1 == 1 and self.numerical_endpoint_2 == 1:
126-
if endpoint is Endpoint.ARROW:
126+
if endpoint == Endpoint.ARROW:
127127
pass
128128
else:
129-
if endpoint is Endpoint.TAIL:
129+
if endpoint == Endpoint.TAIL:
130130
self.numerical_endpoint_1 = 1
131131
self.numerical_endpoint_2 = -1
132132
else:
133-
if endpoint is Endpoint.CIRCLE:
133+
if endpoint == Endpoint.CIRCLE:
134134
self.numerical_endpoint_1 = 1
135135
self.numerical_endpoint_2 = 2
136136
else:
137-
if endpoint is Endpoint.ARROW and self.numerical_endpoint_2 == 1:
137+
if endpoint == Endpoint.ARROW and self.numerical_endpoint_2 == 1:
138138
self.numerical_endpoint_1 = 1
139139
self.numerical_endpoint_2 = 1
140140
else:
141-
if endpoint is Endpoint.ARROW:
141+
if endpoint == Endpoint.ARROW:
142142
self.numerical_endpoint_2 = 1
143143
else:
144-
if endpoint is Endpoint.TAIL:
144+
if endpoint == Endpoint.TAIL:
145145
self.numerical_endpoint_2 = -1
146146
else:
147-
if endpoint is Endpoint.CIRCLE:
147+
if endpoint == Endpoint.CIRCLE:
148148
self.numerical_endpoint_2 = 2
149149

150150
if self.pointing_left(self.endpoint1, self.endpoint2):
@@ -216,20 +216,20 @@ def __str__(self):
216216

217217
edge_string = node1.get_name() + " "
218218

219-
if endpoint1 is Endpoint.TAIL:
219+
if endpoint1 == Endpoint.TAIL:
220220
edge_string = edge_string + "-"
221221
else:
222-
if endpoint1 is Endpoint.ARROW:
222+
if endpoint1 == Endpoint.ARROW:
223223
edge_string = edge_string + "<"
224224
else:
225225
edge_string = edge_string + "o"
226226

227227
edge_string = edge_string + "-"
228228

229-
if endpoint2 is Endpoint.TAIL:
229+
if endpoint2 == Endpoint.TAIL:
230230
edge_string = edge_string + "-"
231231
else:
232-
if endpoint2 is Endpoint.ARROW:
232+
if endpoint2 == Endpoint.ARROW:
233233
edge_string = edge_string + ">"
234234
else:
235235
edge_string = edge_string + "o"

causallearn/graph/Edges.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,29 +30,29 @@ def undirected_edge(self, node_a: Node, node_b: Node) -> Edge:
3030

3131
# return true iff an edge is a bidrected edge <->
3232
def is_bidirected_edge(self, edge: Edge) -> bool:
33-
return edge.get_endpoint1() is Endpoint.ARROW and edge.get_endpoint2() is Endpoint.ARROW
33+
return edge.get_endpoint1() == Endpoint.ARROW and edge.get_endpoint2() == Endpoint.ARROW
3434

3535
# return true iff the given edge is a directed edge -->
3636
def is_directed_edge(self, edge: Edge) -> bool:
37-
if edge.get_endpoint1() is Endpoint.TAIL:
38-
return edge.get_endpoint2() is Endpoint.ARROW
39-
elif edge.get_endpoint2() is Endpoint.TAIL:
40-
return edge.get_endpoint1() is Endpoint.ARROW
37+
if edge.get_endpoint1() == Endpoint.TAIL:
38+
return edge.get_endpoint2() == Endpoint.ARROW
39+
elif edge.get_endpoint2() == Endpoint.TAIL:
40+
return edge.get_endpoint1() == Endpoint.ARROW
4141
else:
4242
return False
4343

4444
# return true iff the given edge is a partially oriented edge o->
4545
def is_partially_oriented_edge(self, edge: Edge) -> bool:
46-
if edge.get_endpoint1() is Endpoint.CIRCLE:
47-
return edge.get_endpoint2() is Endpoint.ARROW
48-
elif edge.get_endpoint2() is Endpoint.CIRCLE:
49-
return edge.get_endpoint1() is Endpoint.ARROW
46+
if edge.get_endpoint1() == Endpoint.CIRCLE:
47+
return edge.get_endpoint2() == Endpoint.ARROW
48+
elif edge.get_endpoint2() == Endpoint.CIRCLE:
49+
return edge.get_endpoint1() == Endpoint.ARROW
5050
else:
5151
return False
5252

5353
# return true iff some edge is an undirected edge --
5454
def is_undirected_edge(self, edge: Edge) -> bool:
55-
return edge.get_endpoint1() is Endpoint.TAIL and edge.get_endpoint2() is Endpoint.TAIL
55+
return edge.get_endpoint1() == Endpoint.TAIL and edge.get_endpoint2() == Endpoint.TAIL
5656

5757
def traverse_directed(self, node: Node, edge: Edge) -> Node | None:
5858
if node == edge.get_node1():

causallearn/graph/Endpoint.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ class Endpoint(Enum):
1818
# Prints out the name of the type
1919
def __str__(self):
2020
return self.name
21+
22+
def __eq__(self, other):
23+
return self.value == other.value

causallearn/utils/GraphUtils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,20 +62,20 @@ def edge_string(self, edge: Edge) -> str:
6262

6363
edge_string = node1.get_name() + " "
6464

65-
if endpoint1 is Endpoint.TAIL:
65+
if endpoint1 == Endpoint.TAIL:
6666
edge_string = edge_string + "-"
6767
else:
68-
if endpoint1 is Endpoint.ARROW:
68+
if endpoint1 == Endpoint.ARROW:
6969
edge_string = edge_string + "<"
7070
else:
7171
edge_string = edge_string + "o"
7272

7373
edge_string = edge_string + "-"
7474

75-
if endpoint2 is Endpoint.TAIL:
75+
if endpoint2 == Endpoint.TAIL:
7676
edge_string = edge_string + "-"
7777
else:
78-
if endpoint2 is Endpoint.ARROW:
78+
if endpoint2 == Endpoint.ARROW:
7979
edge_string = edge_string + ">"
8080
else:
8181
edge_string = edge_string + "o"

0 commit comments

Comments
 (0)