Skip to content

Commit 5d4c174

Browse files
committed
add test for duplicate directed edges
1 parent 1a8dcb6 commit 5d4c174

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

tests/unit/readers/testGraphPropertyReader.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,30 @@ def test_read_data(self) -> None:
2727
msg="The output should be an instance of torch_geometric.data.Data.",
2828
)
2929

30-
assert (
31-
data.edge_index.shape[0] == 2
32-
), f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}"
30+
self.assertEqual(
31+
data.edge_index.shape[0],
32+
2,
33+
msg=f"Expected edge_index to have shape [2, num_edges], but got shape {data.edge_index.shape}",
34+
)
35+
36+
self.assertEqual(
37+
data.edge_index.shape[1],
38+
data.edge_attr.shape[0],
39+
msg=f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})",
40+
)
3341

34-
assert (
35-
data.edge_index.shape[1] == data.edge_attr.shape[0]
36-
), f"Mismatch between number of edges in edge_index ({data.edge_index.shape[1]}) and edge_attr ({data.edge_attr.shape[0]})"
42+
self.assertEqual(
43+
len(set(data.edge_index[0].tolist())),
44+
data.x.shape[0],
45+
msg=f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})",
46+
)
3747

38-
assert (
39-
len(set(data.edge_index[0].tolist())) == data.x.shape[0]
40-
), f"Number of unique source nodes in edge_index ({len(set(data.edge_index[0].tolist()))}) does not match number of nodes in x ({data.x.shape[0]})"
48+
# Check for duplicates by checking if the rows are the same (direction matters)
49+
_, counts = torch.unique(data.edge_index.t(), dim=0, return_counts=True)
50+
self.assertFalse(
51+
torch.any(counts > 1),
52+
msg="There are duplicates of directed edge in edge_index",
53+
)
4154

4255
expected_data: GeomData = self.molecule_graph.get_aspirin_graph()
4356
self.assertTrue(

0 commit comments

Comments
 (0)