Skip to content

Commit 86dabc4

Browse files
Switch from error to warning for inconsistencies in TUDataset Fingerprint. (#203)
* Warning instead of error for inconsistency * Move force consistency to TODO * Add tests and fix graph indicator * Apply suggestions from code review Co-authored-by: Carlo Lucibello <[email protected]> Co-authored-by: Carlo Lucibello <[email protected]>
1 parent 756e76d commit 86dabc4

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

src/datasets/graphs/tudataset.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@ function TUDataset(name; dir=nothing)
6363
source, target = st[:,1], st[:,2]
6464

6565
graph_indicator = readdlm(joinpath(d, "$(name)_graph_indicator.txt"), Int) |> vec
66-
@assert all(sort(unique(graph_indicator)) .== 1:length(unique(graph_indicator)))
66+
if !all(sort(unique(graph_indicator)) .== 1:length(unique(graph_indicator)))
67+
@warn "Some graph indicators are not present in graph_indicator.txt. Ordering of graph and graph labels may not be consistent. Base.getindex might produce unexpected behavior for unaltered data."
68+
end
6769

6870
num_nodes = length(graph_indicator)
6971
num_edges = length(source)
@@ -91,6 +93,8 @@ function TUDataset(name; dir=nothing)
9193
readdlm(joinpath(d, "$(name)_graph_attributes.txt"), ',', Float32)' |> collect :
9294
nothing
9395

96+
# TODO: maybe introduce consistency in graph labels and attributes if possible
97+
9498
# We need this two vectors sorted for efficiency in tudataset_getgraph(full_dataset, i)
9599
@assert issorted(graph_indicator)
96100
if !issorted(source)
@@ -115,7 +119,7 @@ function TUDataset(name; dir=nothing)
115119
edge_attributes,
116120
graph_attributes)
117121

118-
graphs = [tudataset_getgraph(full_dataset, i) for i in 1:num_graphs]
122+
graphs = [tudataset_getgraph(full_dataset, i) for i in sort(unique(graph_indicator))]
119123
graph_data = (; features = graph_attributes, targets = graph_labels) |> clean_nt
120124
metadata = Dict{String, Any}("name" => name)
121125
return TUDataset(name, metadata, graphs, graph_data, num_nodes, num_edges, num_graphs)

test/datasets/graphs_no_ci.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,37 @@ end
313313
@test size(g.edge_data.features) == (4, g.num_edges)
314314
end
315315

316+
@testset "TUDataset - Fingerprint" begin
317+
@test_warn "" TUDataset("Fingerprint")
318+
data = TUDataset("Fingerprint")
319+
320+
@test data.num_nodes == 15167
321+
@test data.num_edges == 24756
322+
@test data.num_graphs == 2149
323+
324+
@test data.num_nodes == sum(g->g.num_nodes, data.graphs)
325+
@test data.num_edges == sum(g->g.num_edges, data.graphs)
326+
@test data.num_edges == sum(g->length(g.edge_index[1]), data.graphs)
327+
@test data.num_edges == sum(g->length(g.edge_index[2]), data.graphs)
328+
@test data.num_graphs == length(data) == length(data.graphs)
329+
330+
i = rand(1:length(data))
331+
@test_throws DimensionMismatch data[i]
332+
g = data.graphs[i]
333+
@test g isa Graph
334+
@test all(1 .<= g.edge_index[1] .<= g.num_nodes)
335+
@test all(1 .<= g.edge_index[2] .<= g.num_nodes)
336+
337+
# graph data
338+
@test size(data.graph_data.targets) == (2800, )
339+
340+
# node data
341+
@test size(g.node_data.features) == (2, g.num_nodes)
342+
343+
# edge data
344+
@test size(g.edge_data.features) == (2, g.num_edges)
345+
end
346+
316347
@testset "OrganicMaterialsDB" begin
317348
data = OrganicMaterialsDB(split=:train)
318349
@test length(data) == 10000

0 commit comments

Comments
 (0)