@@ -27,17 +27,30 @@ def test_read_data(self) -> None:
2727 msg = "The output should be an instance of torch_geometric.data.Data." ,
2828 )
2929
30- assert (
31- data .edge_index .shape [0 ] == 2
32- ), f"Expected edge_index to have shape [2, num_edges], but got shape { data .edge_index .shape } "
30+ self .assertEqual (
31+ data .edge_index .shape [0 ],
32+ 2 ,
33+ msg = f"Expected edge_index to have shape [2, num_edges], but got shape { data .edge_index .shape } " ,
34+ )
35+
36+ self .assertEqual (
37+ data .edge_index .shape [1 ],
38+ data .edge_attr .shape [0 ],
39+ msg = f"Mismatch between number of edges in edge_index ({ data .edge_index .shape [1 ]} ) and edge_attr ({ data .edge_attr .shape [0 ]} )" ,
40+ )
3341
34- assert (
35- data .edge_index .shape [1 ] == data .edge_attr .shape [0 ]
36- ), f"Mismatch between number of edges in edge_index ({ data .edge_index .shape [1 ]} ) and edge_attr ({ data .edge_attr .shape [0 ]} )"
42+ self .assertEqual (
43+ len (set (data .edge_index [0 ].tolist ())),
44+ data .x .shape [0 ],
45+ msg = f"Number of unique source nodes in edge_index ({ len (set (data .edge_index [0 ].tolist ()))} ) does not match number of nodes in x ({ data .x .shape [0 ]} )" ,
46+ )
3747
38- assert (
39- len (set (data .edge_index [0 ].tolist ())) == data .x .shape [0 ]
40- ), f"Number of unique source nodes in edge_index ({ len (set (data .edge_index [0 ].tolist ()))} ) does not match number of nodes in x ({ data .x .shape [0 ]} )"
48+ # Check for duplicates by checking if the rows are the same (direction matters)
49+ _ , counts = torch .unique (data .edge_index .t (), dim = 0 , return_counts = True )
50+ self .assertFalse (
51+ torch .any (counts > 1 ),
52+ msg = "There are duplicates of directed edge in edge_index" ,
53+ )
4154
4255 expected_data : GeomData = self .molecule_graph .get_aspirin_graph ()
4356 self .assertTrue (
0 commit comments